In [6]:
# Project helpers come first so that the explicit imports below win if the
# star import happens to export the same names (this matches the original
# binding order). Later cells use names that are not defined in this notebook
# (Generator, Discriminator, GANLoss, DataLoader, plt, ...) -- presumably they
# are all provided by this star import; verify against mtl_utils.
# NOTE(review): re-running this cell after the class-definition cells may
# rebind any same-named classes exported by mtl_utils -- TODO confirm.
from mtl_utils import *

import functools
import itertools  # used when chaining generator/discriminator parameters

import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler

# Runtime configuration. These statements now follow the explicit
# `import torch` instead of relying on the star import to provide it.
Tensor = torch.cuda.FloatTensor
cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"
print(f"Using {device} device")

Using cuda device


Changing the size of the image requires changing the script code for the discriminator at the linear output line.

In [2]:
# Source domain is read from the Mumbai pickle, target domain from the Shanghai
# pickle; both datasets are created with transform=True.
x1 = Temp_Dataset(
    '/datacommons/carlsonlab/srs108/old/ol/Mumbai.pkl',
    transform=True,
)  # source
x2 = Temp_Dataset(
    '/datacommons/carlsonlab/srs108/old/ol/Shanghai.pkl',
    transform=True,
)  # target

In [3]:
# One loader per domain (x1 -> source, x2 -> target), each yielding
# mini-batches of four samples.
source, target = (DataLoader(dataset, batch_size=4) for dataset in (x1, x2))

In [4]:
#taken directly from cycleGAN Github page.
class ResnetGenerator(nn.Module):
    """Image-to-image generator built around a stack of residual blocks.

    Layout: a 7x7 convolutional stem, two stride-2 downsampling convolutions,
    `n_blocks` ResnetBlock modules, two stride-2 transposed convolutions, and a
    final 7x7 convolution followed by Tanh.
    The design is adapted from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Assemble the generator as a single nn.Sequential stored in `self.model`.

        Parameters:
            input_nc (int)      -- channel count of the input images
            output_nc (int)     -- channel count of the generated images
            ngf (int)           -- base filter count (width of the stem and of
                                   the features fed to the last conv layer)
            norm_layer          -- normalization layer class, or a
                                   functools.partial wrapping one
            use_dropout (bool)  -- forwarded to every ResnetBlock
            n_blocks (int)      -- how many ResnetBlock modules to insert
            padding_type (str)  -- padding used inside the ResnetBlocks:
                                   reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()

        # Convolutions get a bias only when the normalization is InstanceNorm2d;
        # a functools.partial is unwrapped to look at the class it wraps.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        n_downsampling = 2
        layers = []

        # Stem: reflection padding keeps the spatial size through the 7x7 conv.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ])

        # Downsampling: each step halves the resolution and doubles the width.
        channels = ngf
        for _ in range(n_downsampling):
            wider = channels * 2
            layers.extend([
                nn.Conv2d(channels, wider, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(wider),
                nn.ReLU(True),
            ])
            channels = wider

        # Residual trunk at the lowest resolution.
        for _ in range(n_blocks):
            layers.append(
                ResnetBlock(channels, padding_type=padding_type, norm_layer=norm_layer,
                            use_dropout=use_dropout, use_bias=use_bias)
            )

        # Upsampling: mirror of the downsampling path.
        for _ in range(n_downsampling):
            narrower = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrower,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=use_bias),
                norm_layer(narrower),
                nn.ReLU(True),
            ])
            channels = narrower

        # Output head: Tanh bounds the generated values to [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run `input` through the sequential model and return the result."""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Residual block: two 3x3 conv/norm stages wrapped by an identity shortcut."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the convolutional branch and keep it as `self.conv_block`.

        The shortcut itself is applied in `forward`.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Create the convolutional branch of the block.

        Parameters:
            dim (int)           -- channel count, identical for input and output
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- normalization layer factory, called as norm_layer(dim)
            use_dropout (bool)  -- insert Dropout(0.5) after the first ReLU
            use_bias (bool)     -- whether the conv layers carry a bias term

        Returns an nn.Sequential of: [pad] conv norm ReLU [dropout] [pad] conv norm.
        Raises NotImplementedError for an unknown padding_type.
        """

        def padded_conv_stage():
            """Return the layers [explicit pad?, 3x3 conv, norm] for one stage."""
            if padding_type == 'reflect':
                stage, conv_padding = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                stage, conv_padding = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                # Zero padding is delegated to the convolution itself.
                stage, conv_padding = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            stage.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias))
            stage.append(norm_layer(dim))
            return stage

        # First stage is followed by the non-linearity (and optional dropout).
        layers = padded_conv_stage()
        layers.append(nn.ReLU(True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second stage has no activation; the shortcut is added afterwards.
        layers.extend(padded_conv_stage())

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual

In [7]:
# BatchNorm2d with learnable affine parameters and running statistics enabled.
# NOTE(review): the output saved with this cell is a TypeError about
# `padding_type`, although the ResnetBlock defined in this notebook accepts that
# argument -- presumably a different ResnetBlock (e.g. one brought in by
# `from mtl_utils import *`) was bound when the cell ran. Re-check on a fresh
# kernel with Restart & Run All.
norm_layer = functools.partial(
    nn.BatchNorm2d,
    affine=True,
    track_running_stats=True,
)
net = ResnetGenerator(
    input_nc=3,
    output_nc=3,
    ngf=64,
    norm_layer=norm_layer,
    use_dropout=False,
    n_blocks=9,
)

TypeError: __init__() got an unexpected keyword argument 'padding_type'

In [4]:
# Initializing networks
lr = 0.00002
Gen_S_T = Generator(input_dim=3, num_filter=64, num_res=6)  # source -> target
Gen_T_S = Generator(input_dim=3, num_filter=64, num_res=6)  # target -> source

Dis_T = Discriminator(input_dim=3, num_filter=64)
Dis_S = Discriminator(input_dim=3, num_filter=64)

Dis_Feat = Discriminator_Task(input_dim=1, hidden_dim=500)

regressor_S = Regressor_Task()
regressor_T = Regressor_Task()

# Initializing weights. Only the generators and discriminators are
# re-initialised; the regressors keep their default initialisation, exactly as
# before.
for _module in (Gen_S_T, Gen_T_S, Dis_T, Dis_S, Dis_Feat):
    _module.apply(weights_init_normal)

# Move every network to the device selected in the setup cell (`device`)
# instead of calling .cuda() unconditionally, so the CPU fallback is honoured.
for _module in (Gen_S_T, Gen_T_S, Dis_T, Dis_S, Dis_Feat, regressor_S, regressor_T):
    _module.to(device)

# Optimizers
# Both generators share one optimizer, and both image discriminators share one.
optimizer_G = torch.optim.Adam(itertools.chain(Gen_S_T.parameters(), Gen_T_S.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(Dis_T.parameters(), Dis_S.parameters()), lr=lr, betas=(0.5, 0.999))

# The feature discriminator is trained with a 10x smaller learning rate.
optimizer_D_feat = torch.optim.Adam(Dis_Feat.parameters(), lr=lr / 10, betas=(0.5, 0.999))

optimizer_R_S = torch.optim.Adam(regressor_S.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_R_T = torch.optim.Adam(regressor_T.parameters(), lr=lr, betas=(0.5, 0.999))

# Losses
# NOTE(review): the CyCADA class below builds its own loss objects and uses
# BCEWithLogitsLoss for the feature discriminator, whereas `feat_loss` here is
# a CrossEntropyLoss. These module-level objects are kept because cells outside
# this view may still use them -- confirm and remove if they are dead.
f_loss = torch.nn.MSELoss()
cycle_loss = torch.nn.L1Loss()
gan_loss = GANLoss().to(device)
feat_loss = torch.nn.modules.CrossEntropyLoss()






In [5]:
# Everything the CyCADA wrapper needs, keyed by the names its constructor reads.
config = dict(
    # image generators
    generator_S_T=Gen_S_T,
    generator_T_S=Gen_T_S,
    # image-level and feature-level discriminators
    discriminator_S=Dis_S,
    discriminator_T=Dis_T,
    feature_dis=Dis_Feat,
    # task regressors
    regressor_S=regressor_S,
    regressor_T=regressor_T,
    # optimizers
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    optimizer_D_feat=optimizer_D_feat,
    optimizer_R_S=optimizer_R_S,
    optimizer_R_T=optimizer_R_T,
)

In [7]:
class BaseFunction():
    """Mixin with small helpers shared by the training wrapper:
    freezing/unfreezing networks, turning scores into binary predictions with an
    accuracy, and plotting a 2x2 grid of real/generated images.
    """

    def set_requires_grad(self, nets, requires_grad=False):
        """Set `requires_grad` on every parameter of the given networks.

        Parameters:
            nets                 -- a list/tuple of networks, or a single network
            requires_grad (bool) -- the flag value to assign
        """
        if not isinstance(nets, (list, tuple)):
            nets = [nets]
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def class_prediction(self, out, target=None, threshold=0.5):
        """Threshold scores into 0/1 labels and measure accuracy against `target`.

        Parameters:
            out (Tensor)      -- scores, one per sample; detached before use
            target (Tensor)   -- expected 0/1 labels, or None to skip the accuracy
            threshold (float) -- scores strictly above this value become 1

        Returns (label, acc): `label` has the shape, dtype and device of `out`;
        `acc` is the fraction of matching labels, or None when no target is given.

        The accuracy is computed from the thresholded labels. It used to compare
        the raw scores with the target, which is practically never equal and so
        reported an accuracy near zero. Tensors also stay on the device of `out`
        rather than being forced onto CUDA.
        NOTE(review): some callers pass outputs that are also fed to
        BCEWithLogitsLoss; for raw logits the natural cut-off would be 0 --
        confirm which scale `out` is on.
        """
        out = out.detach()
        label = (out > threshold).to(out.dtype)
        if target is None:
            return label, None
        target = target.to(out.device)
        acc = (label == target).sum().item() / target.size()[0]
        return label, acc

    def visualize(self, source, target, fake_source, fake_target):
        """Show real images on the top row and generated images on the bottom row.

        Each argument is expected to be a numpy array. Three-dimensional arrays
        are assumed to be channel-first (C, H, W), as produced by the networks,
        and are rearranged to (H, W, C) for imshow. The previous `.T` reversed
        all axes, which also swapped height and width and drew the image
        transposed.
        NOTE(review): imshow clips float data outside [0, 1]; if the images are
        normalised differently they may need rescaling -- TODO confirm.
        """
        fig, ax = plt.subplots(2, 2, figsize=(8, 8))
        panels = (
            (ax[0, 0], source, 'source'),
            (ax[0, 1], target, 'target'),
            (ax[1, 0], fake_source, 'fake source'),
            (ax[1, 1], fake_target, 'fake target'),
        )
        for axis, image, title in panels:
            if image.ndim == 3:
                image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
            axis.imshow(image)
            axis.set_title(title)

        plt.show()
        # Release the figure; this method is called once per training batch.
        plt.close(fig)

In [None]:
# Minimax.

In [8]:
class CyCADA(nn.Module, BaseFunction):
    """Training wrapper around two image generators, two image discriminators,
    a feature discriminator and two task regressors.

    `runthrough` performs one optimisation round on a (source, target) batch:
    source regressor -> generators -> image discriminators -> target regressor
    -> feature discriminator. All networks and optimizers are supplied through
    the `config` dict.

    Naming note: `dis_s` is trained with real target images against S->T
    translations, and `dis_t` with real source images against T->S translations
    (see `backward_D_S` / `backward_D_T`); the generator loss pairs them the
    same way.
    """

    def __init__(self, config):
        """Store the networks and optimizers from `config` and create the losses.

        Expected keys: generator_S_T, generator_T_S, discriminator_S,
        discriminator_T, feature_dis, regressor_S, regressor_T, optimizer_G,
        optimizer_D, optimizer_D_feat, optimizer_R_S, optimizer_R_T.
        """
        super(CyCADA, self).__init__()

        self.gen_s_t = config['generator_S_T']
        self.gen_t_s = config['generator_T_S']
        self.dis_s = config['discriminator_S']
        self.dis_t = config['discriminator_T']
        self.feat_dis = config['feature_dis']
        self.regressor_s = config['regressor_S']
        self.regressor_t = config['regressor_T']
        self.optimizer_G = config['optimizer_G']
        self.optimizer_D = config['optimizer_D']
        self.optimizer_D_feat = config['optimizer_D_feat']
        self.optimizer_R_S = config['optimizer_R_S']
        self.optimizer_R_T = config['optimizer_R_T']

        self.ganloss = GANLoss()
        self.mseloss = torch.nn.MSELoss()
        self.cycleloss = torch.nn.L1Loss()
        self.featloss = torch.nn.modules.BCEWithLogitsLoss()

    def forward(self, source_image, target_image, s_label):
        """Translate both batches, reconstruct them, and score the feature
        discriminator.

        Results are stored on `self` (fake_*_img, reconstructed_*,
        score_acc_D_ft) for the backward_* methods; nothing is returned.

        Anomaly detection (`torch.autograd.set_detect_anomaly`) is no longer
        switched on here: it is a global debugging aid that slows every backward
        pass. Enable it once, outside the loop, when debugging.
        """
        self.source_image = source_image
        self.target_image = target_image
        self.label = s_label

        # Domain labels live on the same device as the images (1 = source,
        # 0 = target) instead of being forced onto CUDA.
        self.source_label = torch.ones(
            source_image.size()[0], dtype=torch.long, device=source_image.device)
        self.target_label = torch.zeros(
            target_image.size()[0], dtype=torch.long, device=target_image.device)

        # S --> T --> S
        self.fake_target_img = self.gen_s_t(self.source_image)
        self.reconstructed_source = self.gen_t_s(self.fake_target_img)

        # T --> S --> T
        self.fake_source_img = self.gen_t_s(self.target_image)
        self.reconstructed_target = self.gen_s_t(self.fake_source_img)

        # Feature-discriminator accuracy is a monitoring value only, so no
        # autograd graph is built for it.
        with torch.no_grad():
            pred_source = self.feat_dis(self.regressor_t(self.fake_target_img))  # D_ft(C(G(S)))
            pred_target = self.feat_dis(self.regressor_t(self.target_image))     # D_ft(C(T))

        # squeeze(dim=1) matches the loss computations below and, unlike a bare
        # squeeze(), keeps the batch axis when the batch size is 1.
        _, self.acc_d_ft_source = self.class_prediction(pred_source.squeeze(dim=1), self.source_label)
        _, self.acc_d_ft_target = self.class_prediction(pred_target.squeeze(dim=1), self.target_label)
        self.score_acc_D_ft = (self.acc_d_ft_source + self.acc_d_ft_target) / 2

    def backward_regressor_s(self):
        """Consistency loss for the source regressor: its prediction on a real
        image should match its prediction on that image's translation.

        The translations are detached: this step only updates `regressor_s`,
        and backpropagating into the generators here would free the generator
        graph that `backward_gen` still needs.
        NOTE(review): this loss never uses `self.label`; confirm the source
        regressor is meant to be trained without supervision.
        """
        f_s_source = self.mseloss(self.regressor_s(self.source_image),
                                  self.regressor_s(self.fake_target_img.detach()))
        f_s_targ = self.mseloss(self.regressor_s(self.target_image),
                                self.regressor_s(self.fake_source_img.detach()))
        self.f_s_loss = f_s_source + f_s_targ
        self.f_s_loss.backward()

    def backward_gen(self):
        """Generator loss: cycle reconstruction + adversarial + regressor
        consistency, backpropagated into both generators.

        The reconstructions and translations are used without detach(). With
        detach(), as before, the cycle and consistency terms contributed no
        gradient to the generators at all.
        """
        # Cycle consistency: G_T_S(G_S_T(s)) ~ s and G_S_T(G_T_S(t)) ~ t.
        self.cycle_recon_loss_source = self.cycleloss(self.reconstructed_source, self.source_image)
        self.cycle_recon_loss_target = self.cycleloss(self.reconstructed_target, self.target_image)

        # Adversarial terms reuse the translations from forward() rather than
        # running the generators a second time.
        self.gan_loss_S_T = self.ganloss(self.dis_s(self.fake_target_img), True)
        self.gan_loss_T_S = self.ganloss(self.dis_t(self.fake_source_img), True)

        # Regressor consistency: the (frozen) regressors should predict the same
        # value for an image and its translation; gradients reach the
        # generators through the translated images.
        self.rloss_s = self.mseloss(self.regressor_s(self.source_image),
                                    self.regressor_s(self.fake_target_img))
        self.rloss_t = self.mseloss(self.regressor_t(self.target_image),
                                    self.regressor_t(self.fake_source_img))

        self.total_gen = (self.rloss_s + self.rloss_t
                          + self.cycle_recon_loss_source + self.cycle_recon_loss_target
                          + self.gan_loss_S_T + self.gan_loss_T_S)
        self.total_gen.backward()

    def backward_dis(self, netD, real, fake):
        """Compute the loss and accuracy of discriminator `netD`.

        Parameters:
            netD        -- the discriminator to evaluate
            real        -- a batch of real images
            fake        -- a batch of generated images (detached here)

        Returns (loss_D, acc): the mean of the real and fake GAN losses, and the
        mean of the real and fake accuracies. backward() is left to the caller.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.ganloss(pred_real, True)

        # Fake -- detached so the generators receive no gradient.
        pred_fake = netD(fake.detach())
        loss_D_fake = self.ganloss(pred_fake, False)

        loss_D = (loss_D_real + loss_D_fake) * 0.5

        # Discriminator accuracy; labels are created next to the predictions so
        # no device round trip is needed.
        true_labels = torch.ones(real.size()[0], dtype=torch.long, device=pred_real.device)
        fake_labels = torch.zeros(fake.size()[0], dtype=torch.long, device=pred_fake.device)

        _, true_acc = self.class_prediction(pred_real.squeeze(), true_labels)
        _, fake_acc = self.class_prediction(pred_fake.squeeze(), fake_labels)
        acc = (true_acc + fake_acc) * 0.5
        return loss_D, acc

    def backward_D_S(self):
        """Train `dis_s`: real target images vs. S->T translations."""
        self.loss_D_S_from_T, self.score_acc_d_S = self.backward_dis(
            self.dis_s, self.target_image, self.fake_target_img)
        self.loss_D_S_from_T.backward()

    def backward_D_T(self):
        """Train `dis_t`: real source images vs. T->S translations."""
        self.loss_D_T_from_S, self.score_acc_d_T = self.backward_dis(
            self.dis_t, self.source_image, self.fake_source_img)
        self.loss_D_T_from_S.backward()

    def backward_regressor_T(self):
        """Target regressor loss: supervised MSE on translated source images,
        plus an adversarial term once the feature discriminator is accurate
        enough.
        """
        # Supervised term: translated source images carry the source labels.
        # assumes the regressor output and `self.label` have matching shapes --
        # TODO confirm, MSELoss would otherwise broadcast.
        pred_target = self.regressor_t(self.fake_target_img.detach())
        self.loss_reg_mse = self.mseloss(pred_target, self.label)

        if self.score_acc_D_ft > 0.1:
            # Adversarial term: target features are labelled 1 (the source
            # label) so the regressor learns to fool the feature discriminator.
            pred_target = self.feat_dis(self.regressor_t(self.target_image))
            target_label = torch.ones(self.target_image.size()[0], device=pred_target.device)
            self.loss_ft_D_targ = self.featloss(pred_target.squeeze(dim=1), target_label)
        else:
            self.loss_ft_D_targ = 0
        self.loss_reg_t = self.loss_reg_mse + self.loss_ft_D_targ
        self.loss_reg_t.backward()

    def backward_D_feat(self):
        """Feature discriminator loss: separate regressor outputs of translated
        source images (label 1) from those of real target images (label 0).
        """
        # Source
        pred_source = self.feat_dis(self.regressor_t(self.fake_target_img.detach()))
        loss_D_ft_s = self.featloss(pred_source.squeeze(dim=1), self.source_label.float())
        # Target
        pred_target = self.feat_dis(self.regressor_t(self.target_image))
        loss_D_ft_t = self.featloss(pred_target.squeeze(dim=1), self.target_label.float())
        # Combined loss
        self.loss_D_ft_adv = (loss_D_ft_s + loss_D_ft_t) * 0.5
        self.loss_D_ft_adv.backward()

    def runthrough(self, source_image, target_image, s_label):
        """Run one full optimisation round on a pair of batches.

        Parameters:
            source_image -- batch of source-domain images
            target_image -- batch of target-domain images
            s_label      -- regression labels for the source batch
        """
        self.forward(source_image, target_image, s_label)

        # 1) source regressor
        self.set_requires_grad([self.regressor_s], True)
        self.optimizer_R_S.zero_grad()
        self.backward_regressor_s()
        self.optimizer_R_S.step()

        # 2) generators -- every other network on the loss path is frozen so
        #    that only the generators accumulate gradients.
        self.set_requires_grad([self.dis_s, self.dis_t, self.regressor_s, self.regressor_t], False)
        self.set_requires_grad([self.gen_s_t, self.gen_t_s], True)
        self.optimizer_G.zero_grad()
        self.optimizer_R_S.zero_grad()
        self.backward_gen()
        self.optimizer_G.step()

        # Show the first sample of the batch next to its translations.
        self.visualize(self.source_image[0].cpu().numpy(),
                       self.target_image[0].cpu().numpy(),
                       self.fake_source_img[0].detach().cpu().numpy(),
                       self.fake_target_img[0].detach().cpu().numpy())

        # 3) image discriminators
        self.set_requires_grad([self.dis_s, self.dis_t], True)
        self.optimizer_D.zero_grad()
        self.backward_D_S()
        self.backward_D_T()
        self.optimizer_D.step()

        # 4) target regressor (feature discriminator frozen)
        self.set_requires_grad([self.feat_dis], False)
        self.set_requires_grad([self.regressor_t], True)
        self.optimizer_R_T.zero_grad()
        self.backward_regressor_T()
        self.optimizer_R_T.step()

        # 5) feature discriminator (target regressor frozen)
        self.set_requires_grad([self.regressor_t], False)
        self.set_requires_grad([self.feat_dis], True)
        self.optimizer_D_feat.zero_grad()
        self.backward_D_feat()
        self.optimizer_D_feat.step()
        

In [1]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CyCADA(config).to(device)
model.train()  # training mode does not change between batches

i = 0  # number of completed runs; stays 0 if either loader is empty
# zip() stops at the shorter loader, so each run pairs one source batch with
# one target batch. enumerate(start=1) does the counting; the loop variable is
# no longer incremented by hand.
for i, (s, t) in enumerate(zip(source, target), start=1):
    source_image = s['img'].to(device).float()
    target_image = t['img'].to(device).float()
    source_label = s['lbl'].to(device).float()

    model.runthrough(source_image, target_image, source_label)

    print(f'run:{i}')