In [1]:
from augmented_modules import ResnetBlock, CondInstanceNorm, TwoInputSequential, CINResnetBlock, InstanceNorm2d
from augmented_utils import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'On {device}')
Tensor  = torch.cuda.FloatTensor


Figure directory exists.
On cuda:0


In [2]:
#for quicker debugging purposes
class CustomDataset(Dataset):
    def __init__(self, num_samples):
        # Generate synthetic data (you can replace this with your actual data)
        self.data1 = torch.randn(num_samples, 3, 224, 224)
        self.data2 = torch.randn(num_samples, 3, 224, 224)
        self.labels1 = torch.arange(num_samples)  # Dummy labels (you can modify this)
        self.labels2 = torch.arange(num_samples)

    def __len__(self):
        return len(self.data1)

    def __getitem__(self, idx):
        sample = {
              'D img': self.data1[idx],
              'D pm' : self.labels1[idx],
              'L img': self.data2[idx],
              'L pm' : self.labels2[idx]
        }
        return sample

# Create an instance of your custom dataset
num_samples = 3000  # Adjust as needed
custom_dataset = CustomDataset(num_samples)

# Create a DataLoader
batch_size = 5
train_dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

In [3]:
# norm_layer_ = functools.partial(nn.InstanceNorm2d, affine=False)
# mod = Latent_Encoder(nlatent=3, input_nc=3, nef=32, norm_layer=norm_layer_)
# ex = torch.rand((5,3,224,224))

# g = mod(ex)

In [8]:
class Pixel_Level_Augmented_CycleGAN():
    def __init__(self):
        super(Pixel_Level_Augmented_CycleGAN,self).__init__()
        
        norm_layer_C = CondInstanceNorm
        norm_layer_ = functools.partial(nn.InstanceNorm2d, affine=False)
        
        self.G_ST = CINResnetGenerator(nlatent=16, input_nc=3, output_nc=3, ngf=64, norm_layer=norm_layer_C,
                 use_dropout=False, n_blocks=3, gpu_ids=[], padding_type='reflect')
        
        self.G_TS = CINResnetGenerator(nlatent=16, input_nc=3, output_nc=3, ngf=64, norm_layer=norm_layer_C,
                 use_dropout=False, n_blocks=3, gpu_ids=[], padding_type='reflect')

        self.D_T = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4, norm_layer=norm_layer_)
        self.D_S = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=4, norm_layer=norm_layer_)

        self.D_zs = Discriminator_Latent(nlatent=16, ndf=64)
        self.D_zt = Discriminator_Latent(nlatent=16, ndf=64)
        
        self.E_S = Latent_Encoder(nlatent=16, input_nc=6, nef=32, norm_layer=norm_layer_)
        self.E_T = Latent_Encoder(nlatent=16, input_nc=6, nef=32, norm_layer=norm_layer_)

        self.G_ST.to(device)
        self.G_TS.to(device)
        self.E_S.to(device)
        self.E_T.to(device)
        self.D_S.to(device)
        self.D_T.to(device)
        self.D_z.to(device)


        self.ganloss = GANLoss().to(device)       
        self.cycleloss = torch.nn.L1Loss().to(device)               #difference between reconstructed img and original
        self.mseloss = torch.nn.MSELoss().to(device)       #difference between domain classifications between input img and generator output
        self.identityloss = torch.nn.L1Loss().to(device)
        
        self.optimizer_GS = torch.optim.Adam(itertools.chain(self.G_ST.parameters(), self.E_S.parameters()), 
                                            lr=2e-5, betas=(0.5,0.999))
        
        self.optimizer_GT = torch.optim.Adam(itertools.chain(self.G_TS.parameters(), self.E_T.parameters()), 
                                            lr=2e-5, betas=(0.5,0.999))
        
        self.optimizer_DS = torch.optim.Adam(itertools.chain(self.D_S.parameters(),self.D_z.parameters())
                                             , lr = 1e-5, betas = (0.5,0.9))
        self.optimizer_DT = torch.optim.Adam(itertools.chain(self.D_T.parameters(),self.D_z.parameters())
                                             , lr = 1e-5, betas = (0.5,0.9))


        self.G_ST.apply(weights_init_normal)
        self.G_TS.apply(weights_init_normal)
        self.D_S.apply(weights_init_normal)
        self.D_T.apply(weights_init_normal)
#         self.D_z.apply(weights_init_normal)

        print('initialized')
        
    def data_input(self, batch):
        self.real_S = batch['D img'].type(Tensor)
        self.real_T = batch['L img'].type(Tensor)
        self.real_lbl = batch['D pm'].type(Tensor).float()
        self.prior_S = self.real_S.data.new(self.real_S.size(0), 16, 1, 1).normal_(0, 1)
        self.prior_T = self.real_T.data.new(self.real_T.size(0), 16, 1, 1).normal_(0, 1)
                  
    def forward_pass(self):
        self.t_tilde = self.G_ST(self.real_S, self.prior_S) #shape checks out [BS, CH, H, W] ~ [5, 3, 224, 224] 
        self.s_tilde = self.G_TS(self.real_T, self.prior_T) #shape checks out [BS, CH, H, W] ~ [5, 3, 224, 224] 
        
        
        self.s_mu, self.s_logvar = self.E_S(torch.cat((self.t_tilde, self.real_S), 1))
        
        self.t_mu, self.t_logvar= self.E_T(torch.cat((self.s_tilde,self.real_T),1))

        self.zeta_tilde_s = reparameterize(self.s_mu, self.s_logvar)
        self.zeta_tilde_t = reparameterize(self.t_mu, self.t_logvar)
        
        self.s_recon = self.G_TS(self.t_tilde, self.zeta_tilde_s)
        self.t_recon = self.G_TS(self.s_tilde, self.zeta_tilde_t)
        
        self.zeta_t_prime = self.E_T(torch.cat((self.real_S, self.t_tilde),1))
        self.zeta_s_primt = self.E_S(torch.cat((self.real_T, self.s_tilde),1))
        
        
    def backward_D(self, netD, real, fake):
         
        # Real
        pred_real = netD(real)
        loss_D_real = self.ganloss(pred_real, True)
        
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.ganloss(pred_fake, False)
        

        loss_D = (loss_D_real + loss_D_fake) * 0.5

        return loss_D
    
    def backward_Discriminators(self): #L^T_{GAN}(G_{ST}, D_T)
        self.loss_DS = self.backward_D(self.D_T, self.real_S, self.t_tilde.detach())
        
        self.loss_DT = self.backward_D(self.D_S, self.real_T, self.s_tilde.detach())
        
        self.loss_DZ = self.backward_D(self.D_z, self.prior_S, self.zeta_tilde_s.detach())
        
        self.loss_D = self.loss_DS + self.loss_DT +self.loss_DZ
        
        self.loss_D.backward()
        
    def optimize(self):
        self.forward_pass()
        
        set_requires_grad([self.D_S, self.D_T, self.D_z], requires_grad=True)
        set_requires_grad([self.G_ST, self.G_TS, self.E], requires_grad=False)
        self.optimizer_DS.zero_grad()
        self.optimizer_DT.zero_grad()
        
        self.backward_Discriminators()
        self.optimizer_DS.step()


In [9]:
history = {'epoch':[],'G_loss': [], 'DS_loss':[], 'DT_loss':[], 'batch':[]}
torch.cuda.empty_cache()

model = Pixel_Level_Augmented_CycleGAN()
best_gen_loss = 1e6
best_DT_loss = 1e6
best_DS_loss = 1e6

n_epochs = 15

for e in range(n_epochs):
    for i, batch in tqdm(enumerate(train_dataloader)):
        model.data_input(batch)
        model.optimize()
        
      

AttributeError: 'Pixel_Level_Augmented_CycleGAN' object has no attribute 'D_z'