In [6]:
from utils import *
import pandas as pd
# Seed torch's global RNG so runs are repeatable. It drives DataLoader
# shuffling, torchvision RandomCrop/RandomHorizontalFlip and (presumably) the
# weight initialisation in utils. torch.manual_seed seeds every device,
# including CUDA. Change SEED to explore other runs.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'On {device}')

On cuda:0


In [7]:
# Dataset root. Set the CYCLEGAN_DATA_ROOT environment variable to point
# elsewhere instead of editing this hard-coded cluster path.
root = os.environ.get('CYCLEGAN_DATA_ROOT', '/datacommons/carlsonlab/srs108/archive/')

# Training-time preprocessing with augmentation: upscale by 12%, take a
# random (img_height, img_width) crop, randomly flip horizontally, then map
# pixel values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transforms_ = [
    transforms.Resize(int(img_height*1.12), Image.BICUBIC),
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Deterministic preprocessing for the test split: same resize scale and
# normalisation as training, but a center crop and no flip. Sampled
# validation images then reflect the data rather than random augmentation.
val_transforms_ = [
    transforms.Resize(int(img_height*1.12), Image.BICUBIC),
    transforms.CenterCrop((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    ImageDataset(root, transforms_=transforms_, unaligned=False, mode='train'),
    batch_size=3,
    shuffle=True,
)

val_dataloader = DataLoader(
    ImageDataset(root, transforms_=val_transforms_, unaligned=False, mode='test'),
    batch_size=1,
    shuffle=True,
)

In [3]:
# x = ImageDataset(root, transforms_=transforms_, unaligned=True, mode='train'),


In [4]:
# v = next(iter(dataloader))
# plt.imshow(v['B'].squeeze(dim=0).T)

In [16]:
class Basic_CycleGAN():
    def __init__(self):
        super(Basic_CycleGAN,self).__init__()
        
        norm_layer_G = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        norm_layer_D = functools.partial(nn.InstanceNorm2d, affine=False)
        self.G_ST = ResnetGenerator(3, 3, 64, norm_layer=norm_layer_G, use_dropout=False, n_blocks=9) #9
        self.G_TS = ResnetGenerator(3, 3, 64, norm_layer=norm_layer_G, use_dropout=False, n_blocks=9) #9

        self.D_T = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4, norm_layer=norm_layer_D)
        self.D_S = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4, norm_layer=norm_layer_D)

    
        self.G_ST.to(device)
        self.G_TS.to(device)
        self.D_S.to(device)
        self.D_T.to(device)

        self.ganloss = GANLoss(gan_mode='vanilla').to(device)       
        self.cycleloss = torch.nn.L1Loss().to(device)      #difference between reconstructed img and original
        self.identityloss = torch.nn.L1Loss().to(device)
        
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.G_ST.parameters(), self.G_TS.parameters()), lr=2e-5, betas=(0.5,0.999))
        self.optimizer_D = torch.optim.Adam(self.D_S.parameters(), lr = 1e-4, betas = (0.5,0.999))
#         self.optimizer_DT = torch.optim.Adam(self.D_T.parameters(), lr = 1e-5, betas = (0.5,0.999))


        self.G_ST.apply(weights_init_normal)
        self.G_TS.apply(weights_init_normal)
        self.D_S.apply(weights_init_normal)
        self.D_T.apply(weights_init_normal)

        print('initialized')
        
    def data_input(self, batch):
        self.real_S = batch['A'].type(Tensor)
        self.real_T = batch['B'].type(Tensor)
        
    def sample_images(self, dataloader, epochs, iters, save = False):
        source = next(iter(dataloader))
        self.G_ST.eval()
        self.G_TS.eval()
        real_source = source['A'].type(Tensor) 
        fake_target = self.G_ST(real_source).detach()
        real_target = source['B'].type(Tensor)
        fake_source = self.G_TS(real_target).detach()

        recons = self.G_TS(fake_target).detach()
        recont = self.G_ST(fake_source).detach() 

        real_S = make_grid(real_source, nrow=5, normalize=True, scale_each=True, padding=1)
        fake_T = make_grid(fake_target, nrow=5, normalize=True, scale_each=True, padding=1)
        reconS = make_grid(recons, nrow=5, normalize=True, scale_each=True, padding=1)
        real_T = make_grid(real_target, nrow=5, normalize=True, scale_each=True, padding=1)
        fake_S = make_grid(fake_source, nrow=5, normalize=True, scale_each=True, padding=1)
        reconT = make_grid(recont, nrow=5, normalize=True, scale_each=True, padding=1)
        # Arange images along y-axis    
        image_grid = torch.cat((real_S, fake_T, reconS, real_T, fake_S, reconT), 2)
        plt.imshow(image_grid.cpu().permute(1,2,0))
        plt.title('Real Source | Fake Target | Recon Source | Real Target | Fake Source | Recon Target')
        plt.axis('off')
        plt.gcf().set_size_inches(10, 6)
        if save:
            plt.savefig(os.path.join('Figure_PDFs', f'epoch_{str(e+1)}_iter{str(i+1)}.png'), bbox_inches='tight', pad_inches=0, facecolor='white')
        plt.show();
        
    def forward_pass(self):
        self.fake_t = self.G_ST(self.real_S)
        self.fake_s = self.G_TS(self.real_T)             
        
        self.recov_s = self.G_TS(self.fake_t)
        self.recov_t = self.G_ST(self.fake_s)
        
    def backward_D(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.ganloss(pred_real, True)
        
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.ganloss(pred_fake, False)
        
        loss_D = (loss_D_real + loss_D_fake)/2

        return loss_D

    def backward_DS(self):
        self.DS_Loss = self.backward_D(self.D_S, self.real_T ,self.fake_t)
        self.DS_Loss.backward()
        
    def backward_DT(self):
        self.DT_Loss = self.backward_D(self.D_T, self.real_S, self.fake_s)
        self.DT_Loss.backward()
        
    def backward_G(self):

        idt_S = self.G_ST(self.real_T)  #G_ST(t) ~ s
        idt_T = self.G_TS(self.real_S)  #G_TS(G_ST(s))
        
        loss_idt_S = self.identityloss(idt_S, self.real_T)
        loss_idt_T = self.identityloss(idt_T, self.real_S)
        self.loss_idt = (loss_idt_S + loss_idt_T)/2

        loss_G_S = self.ganloss(self.D_S(self.fake_t), True)       #L(D_S(G_ST(S)))
        loss_G_T = self.ganloss(self.D_T(self.fake_s), True)       #L(D_T(G_TS(T)))    
        self.loss_GAN = (loss_G_S + loss_G_T)/2
        
        loss_cycle_S   = self.cycleloss(self.recov_s, self.real_S)   # Lcyc(G_TS(G_ST(S)), S) * λ 
        loss_cycle_T   = self.cycleloss(self.recov_t, self.real_T)  # Lcyc(G_ST(G_TS(T)), T) * λ 
        self.loss_cycle = (loss_cycle_S + loss_cycle_T)/2
        
        self.loss_G = self.loss_GAN + (self.loss_cycle*10) + (self.loss_idt*5)
        
        self.loss_G.backward()

    def optimize(self):
        self.forward_pass()
        
# ----------------------------------------------------------------------------
# Train Discriminators D_T and D_S;
# -----------------------------------------------------------------------------------
        self.optimizer_D.zero_grad()

        set_requires_grad([self.D_S, self.D_T],requires_grad=True)
        self.backward_DS()
        self.backward_DT()
        
        self.optimizer_D.step()


# -------------------------------------------------------------------------------------------------------
# Train Generators G_ST and G_TS L_CYC;
# -------------------------------------------------------------------------------------------------------
        self.optimizer_G.zero_grad()
        set_requires_grad([self.G_ST, self.G_TS],requires_grad=True)
        self.backward_G()
        self.optimizer_G.step()
        
        return self.loss_G, self.DS_Loss, self.DT_Loss, self.D_S, self.D_T, self.G_ST, self.G_TS

In [1]:
# history = {'epoch':[],'G_loss': [], 'DS_loss':[], 'DT_loss':[], 'batch':[]}

# model = Basic_CycleGAN()
# best_gen_loss = 1e6
# best_DT_loss = 1e6
# best_DS_loss = 1e6
# Tensor  = torch.cuda.FloatTensor

# n_epochs = 100

# for e in range(n_epochs):
#     for i, batch in tqdm(enumerate(dataloader)):
#         model.data_input(batch)
#         G_loss, DS_loss, DT_loss, DS, DT, GST, GTS = model.optimize()

        
#         if (DT_loss+DS_loss) < best_DT_loss:
#             best_DT_loss = (DT_loss+DS_loss)
#             torch.save({'D_T': DT.state_dict(), 'D_S': DS.state_dict()}, 'best_D.pt')
            
#         if G_loss < best_gen_loss:
#             best_gen_loss = G_loss
#             torch.save({ 'G_ST': GST.state_dict(),'G_TS': GTS.state_dict()}, 'best_G.pt')
           
        
#         if (i+1) % 50 == 0:
#             with torch.no_grad():
#                 model.sample_images(val_dataloader, e, i, save=False)
                
#         history['G_loss'].append(G_loss.item())
#         history['DS_loss'].append(DS_loss.item())
#         history['DT_loss'].append(DT_loss.item())
#         history['batch'].append(i+1)
#         history['epoch'].append(e+1)
        
        
#     print(f"Epoch {e + 1}/{n_epochs}\n\
#         [G Loss: {round(G_loss.item(), 4)}]\t[D Loss: {round(DS_loss.item()+ DT_loss.item(),4)}]")

In [None]:
# `history` (dict of equal-length lists: epoch, G_loss, DS_loss, DT_loss, batch)
# is built by the training-loop cell. That cell is currently commented out,
# so on a fresh kernel this cell raises NameError until it is restored.
df = pd.DataFrame.from_dict(history)
df.to_csv(path_or_buf='history.csv', index=False)

In [None]:
# def gen_dis_loss(genloss, disloss, epochs, save = True, fig_name=''):
#     epoch = range(epochs)
#     fig, ax = plt.subplots(1,1, figsize = (6,6))   
#     ax.plot(epoch, genloss, color='b', linewidth=0.5, label='Generator')
#     ax.plot(epoch, disloss, color='r', linewidth=0.5, label='Discriminator')
#     ax.set_xlabel('Iters')
#     ax.set_ylabel('Loss')
#     ax.set_title('Generator and Discriminator Loss')
#     ax.legend()
#     plt.show()
#     if save==True:
#         fig.savefig(PROJECT_ROOT_DIR+'/'+PROJECT_SAVE_DIR+'/'+fig_name+'.png', transparent=False, facecolor='white', bbox_inches='tight')

In [None]:
# gen_dis_loss(G_loss, D_loss, len(D_loss), save = False, fig_name='gdloss')