In [None]:
# Fetch the UTFPR-SBD3 dataset repository into the current working directory.
!git clone https://github.com/bioinfolabic/UTFPR-SBD3.git

Cloning into 'UTFPR-SBD3'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 60 (delta 0), reused 3 (delta 0), pack-reused 56[K
Receiving objects: 100% (60/60), 355.28 MiB | 9.38 MiB/s, done.
Resolving deltas: 100% (24/24), done.


In [None]:
# Work from inside the repository: later cells use relative paths such as
# 'images/', 'masks/' and 'masked_images/'. Then run the repo's extraction script.
%cd UTFPR-SBD3
!python3 extract_dataset.py

/content/UTFPR-SBD3
Finished!


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.utils.data as utils
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
from PIL import Image
import os
import numpy as np
from tqdm import  tqdm
# Black out the background: every pixel whose value in masks/<stem>.png is 0
# is set to 0 (all channels), and the result is written to masked_images/.
os.makedirs(f'{os.getcwd()}/masked_images', exist_ok=True)
for image_name in tqdm(sorted(os.listdir('images'))):
  # Swap only the extension; str.replace('jpg', 'png') would also rewrite a
  # 'jpg' occurring inside the stem and miss other extensions such as '.jpeg'.
  mask_name = os.path.splitext(image_name)[0] + '.png'
  # Context managers close both file handles as soon as the pixels are read.
  with Image.open(f'images/{image_name}') as image_file, Image.open(f'masks/{mask_name}') as mask_file:
    img = np.array(image_file)
    mask = np.array(mask_file)
  img[mask == 0] = 0
  Image.fromarray(img).save(f'masked_images/{image_name}')


100%|██████████| 4500/4500 [00:33<00:00, 132.48it/s]


In [None]:
class SegmentationDataset(Dataset):
    """Paired (masked image, annotation) dataset for annotation-to-image training.

    Sample ``i`` pairs ``<root_dir>/masked_images/<name_i>`` with
    ``<root_dir>/annotations/<stem_i>.png``. Here ``name_i`` is the i-th entry of
    the sorted image listing and ``stem_i`` is that name without its extension.

    Args:
        root_dir: dataset root containing ``masked_images/`` and ``annotations/``.
        transform_image: optional callable applied to the RGB image.
        transform_mask: optional callable applied to the single-channel annotation.

    Raises:
        FileNotFoundError: if an image has no matching annotation file.
    """

    def __init__(self, root_dir, transform_image=None, transform_mask=None):
        self.root_dir = root_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.images_dir = os.path.join(root_dir, 'masked_images')  # background-masked RGB images
        self.masks_dir = os.path.join(root_dir, 'annotations')     # single-channel annotation maps
        self.image_filenames = sorted(os.listdir(self.images_dir))
        # Derive each annotation name from its image's stem instead of zipping two
        # independently sorted listings. Zipping silently misaligns every later
        # pair as soon as either folder has an extra or a missing file.
        self.mask_filenames = [os.path.splitext(name)[0] + '.png' for name in self.image_filenames]
        missing = [name for name in self.mask_filenames
                   if not os.path.isfile(os.path.join(self.masks_dir, name))]
        if missing:
            raise FileNotFoundError(
                f'{len(missing)} annotation file(s) missing from {self.masks_dir}, e.g. {missing[0]}')

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for index ``idx``, each transformed if a transform is set."""
        img_path = os.path.join(self.images_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.masks_dir, self.mask_filenames[idx])

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # 'L' mode for single-channel masks

        # Each transform is applied independently; supplying only one of them
        # must not silently disable it.
        if self.transform_image is not None:
            image = self.transform_image(image)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)

        return image, mask

# Preprocessing shared by the training and preview datasets.
# Nearest-neighbour resizing is required for the annotation map (interpolating
# would invent label values). It is also used for the image, kept as-is so
# previously saved checkpoints see the same preprocessing.
# InterpolationMode replaces the deprecated PIL integer constant (same resampling).
IMAGE_SIZE = (512, 512)
transform_image = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
    # [0, 1] -> [-1, 1], the range of the generator's Tanh output.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_mask = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])
dataset = SegmentationDataset(root_dir='/content/UTFPR-SBD3', transform_image=transform_image, transform_mask=transform_mask)


In [None]:
class SegmentationTestDataset(Dataset):
    """Small fixed set of (masked image, annotation) pairs used for previews.

    Args:
        root_dir: dataset root containing ``masked_images/`` and ``annotations/``.
        transform_image: optional callable applied to the RGB image.
        transform_mask: optional callable applied to the single-channel annotation.
        image_filenames: optional iterable of file names inside ``masked_images/``.
            Defaults to the five preview samples used by this notebook. Each
            annotation is ``annotations/<stem>.png`` for the image's stem.
    """

    DEFAULT_IMAGE_FILENAMES = ('01967.jpg', '02145.jpg', '01407.jpg', '01918.jpg', '00342.jpg')

    def __init__(self, root_dir, transform_image=None, transform_mask=None, image_filenames=None):
        self.root_dir = root_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.images_dir = os.path.join(root_dir, 'masked_images')  # background-masked RGB images
        self.masks_dir = os.path.join(root_dir, 'annotations')     # single-channel annotation maps
        if image_filenames is None:
            image_filenames = self.DEFAULT_IMAGE_FILENAMES
        self.image_filenames = list(image_filenames)
        # Derived from the image names so the two lists cannot drift apart.
        self.mask_filenames = [os.path.splitext(name)[0] + '.png' for name in self.image_filenames]

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for index ``idx``, each transformed if a transform is set."""
        img_path = os.path.join(self.images_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.masks_dir, self.mask_filenames[idx])

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # 'L' mode for single-channel masks

        # Applied independently: passing only one transform must not disable it.
        if self.transform_image is not None:
            image = self.transform_image(image)
        if self.transform_mask is not None:
            mask = self.transform_mask(mask)

        return image, mask

In [None]:
batch_size=1
test_dataset = SegmentationTestDataset(root_dir='/content/UTFPR-SBD3', transform_image=transform_image,transform_mask= transform_mask)
# shuffle=False: per-epoch previews are saved as '{epoch}_{kk}_fake.jpg', so a
# fixed order keeps the same sample under the same index in every epoch and
# makes the previews comparable over training.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                              shuffle=False)

In [None]:
# Training loader: two samples per step, reshuffled every epoch.
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
def weights_init(m):
    """DCGAN/pix2pix-style initialisation, intended for ``model.apply(weights_init)``.

    * ``Conv2d`` and ``ConvTranspose2d``: weights ~ N(0, 0.02), biases (if any) set to 0.
    * ``*BatchNorm*`` layers that have affine parameters: weights ~ N(1, 0.02),
      biases set to 0.

    Any other module is left untouched.
    """
    # isinstance instead of a 'Conv2d' substring test: the substring check
    # missed ConvTranspose2d, leaving the whole decoder at default init.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif 'BatchNorm' in m.__class__.__name__ and getattr(m, 'weight', None) is not None:
        # affine=False BatchNorm has weight=None; skip it instead of crashing.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from torch import nn
import torch
import torchvision
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler

class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock levels.

    Construction starts at the bottleneck and wraps outward:
    innermost (ngf*8) -> ``num_downs - 5`` extra ngf*8 levels -> ngf*4 -> ngf*2
    -> ngf -> an outermost level mapping ``input_nc`` to ``output_nc`` channels.

    Args:
        input_nc: channels of the input tensor.
        output_nc: channels of the generated tensor.
        num_downs: number of down-sampling levels (values below 5 behave like 5).
        ngf: filter count of the outermost level.
        norm_layer: normalisation layer class (or a functools.partial of one).
        use_dropout: add Dropout(0.5) to the extra ngf*8 levels only.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # Taper the channels back down: (outer, inner) = (4, 8), (2, 4), (1, 2) x ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * mult, ngf * mult * 2, input_nc=None, submodule=block,
                                            norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run the full U-Net on ``input``."""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: down-sample, run a nested submodule, then up-sample.

    Non-outermost levels return ``cat([x, level_output], dim=1)``, which gives
    the enclosing level its skip connection. The outermost level returns only
    its own output.

    NOTE: ``downrelu`` is in-place. For non-outermost levels, the ``x`` that
    forward() concatenates has therefore already been overwritten with
    LeakyReLU(x).

    NOTE(review): the layer order inside ``self.model`` fixes the state_dict
    keys. Reordering it would break loading previously saved checkpoints.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Build the level.

        Args:
            outer_nc: channels produced by the up-sampling (transposed) conv.
            inner_nc: channels produced by the down-sampling conv.
            input_nc: channels of this level's input; defaults to ``outer_nc``.
            submodule: nested level placed between the down and up paths
                (not used when ``innermost``).
            outermost: final level. It has no norm/activation before the down
                conv, a Tanh after the up conv, and no skip concatenation.
            innermost: bottleneck level. It has no submodule, so its up conv
                takes ``inner_nc`` instead of ``inner_nc * 2`` channels.
            norm_layer: normalisation class, or a functools.partial wrapping one.
            use_dropout: append Dropout(0.5) after the up path (middle levels only).
        """

        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Conv biases are enabled only when the norm layer is InstanceNorm2d.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        # kernel 4 / stride 2 / padding 1 halves H and W (for even sizes).
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        # downnorm / upnorm are always constructed, but only the branches that
        # list them actually add them to the model.
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # inner_nc * 2: the submodule returns its input concatenated with its
            # output (assumes submodule's outer_nc == inner_nc, as UnetGenerator wires it).
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Bottleneck: no submodule, hence no concatenation feeding upconv.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the level; non-outermost levels prepend ``x`` along the channel dim."""
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)

# 1-channel annotation in, 3-channel image out, 8 down-sampling levels, 64 base filters.
gen = UnetGenerator(input_nc=1, output_nc=3, num_downs=8, ngf=64,
                    norm_layer=nn.BatchNorm2d, use_dropout=False)

In [None]:
class Discriminator(nn.Module):
    """Conditional convolutional critic over (image, annotation) channel stacks.

    Four stride-2 convolutions (64 -> 128 -> 256 -> 512 filters; instance norm
    on all but the first) are followed by a stride-1 conv to one channel. The
    output is a spatial map of raw scores; no sigmoid is applied.

    Args:
        input_dim: channels of the concatenated input.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Conv2d(input_dim, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
                nn.InstanceNorm2d(c_out, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1)))
        # A flat Sequential keeps the same module indices as before, so saved
        # state_dict keys (model.0, model.2, ... model.11) still match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

disc = Discriminator(4)

In [None]:
# Put both networks on the active device, then re-initialise their weights
# (generator first, then discriminator).
gen, disc = gen.to(device), disc.to(device)
gen.apply(weights_init)
disc.apply(weights_init)

# Adam for both players with lr 2e-4 and betas (0.5, 0.999).
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Reconstruction term for the generator.
l1_loss = nn.L1Loss()
# NOTE: despite its name this is a least-squares (MSE) adversarial criterion,
# not binary cross-entropy; the name is kept because the training loop uses it.
bce_loss = nn.MSELoss()

In [None]:
def set_requires_grad(model, requires_grad=False):
    """Freeze (default) or unfreeze every parameter of ``model`` in place.

    The training loop freezes the discriminator during the generator update so
    that the generator's loss does not populate discriminator gradients.
    """
    for p in model.parameters():
        p.requires_grad_(requires_grad)

In [None]:
# Colab-only: mount Google Drive, where checkpoints are loaded from and saved to.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Resume from the "*_512_48.pth" weights stored on Drive. map_location lets
# tensors saved from a GPU load on whatever device this session has. Without
# it, torch.load restores them to their original device, which fails on CPU-only
# runtimes.
gen.load_state_dict(torch.load('/content/drive/MyDrive/gen_512_48.pth', map_location=device))
disc.load_state_dict(torch.load('/content/drive/MyDrive/disc_512_48.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Epoch numbering continues right after the "*_512_48.pth" weights.
epoch_start = 49

# Alternative: resume from a full checkpoint saved as 'model.pt' (networks and
# optimizers). Its 'epoch' entry is the training loop's `epoch` at save time;
# add 1 if that epoch had finished.
# checkpoint = torch.load('model.pt', map_location=device)
# gen.load_state_dict(checkpoint['gen_state_dict'])
# gen_opt.load_state_dict(checkpoint['genopt_state_dict'])
# disc.load_state_dict(checkpoint['disc_state_dict'])
# disc_opt.load_state_dict(checkpoint['discopt_state_dict'])
# epoch_start = checkpoint['epoch'] + 1


In [None]:
def set_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: any ``torch.optim`` optimizer.
        lr: new learning rate (float).
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Example: drop both learning rates to 2e-5 before continuing training.
# set_learning_rate(disc_opt, 0.00002)
# set_learning_rate(gen_opt, 0.00002)

In [None]:
# Adversarial training. Each step updates the discriminator once (real vs.
# detached fake pairs), then the generator once (adversarial + weighted L1).
# After every epoch, previews of test_dataloader are written to results_dir and
# the full training state is saved so an interrupted run can be resumed.
epochs = 100
lambda_l1 = 100  # weight of the L1 reconstruction term relative to the adversarial term
# Previews are created in and saved to this one directory.
results_dir = '/content/UTFPR-SBD3/Test_results'
checkpoint_path = 'model.pt'  # same dict layout as the manual 'model.pt' checkpoint cell
os.makedirs(results_dir, exist_ok=True)
denormalize = transforms.Normalize((-1, -1, -1), (2, 2, 2))  # (x + 1) / 2: [-1, 1] -> [0, 1]
n_batches = max(len(train_dataloader), 1)

for epoch in range(epoch_start, epochs + epoch_start):
  gen.train()
  disc.train()
  disc_loss_sum = 0.0
  gen_loss_sum = 0.0

  for (images, masks) in tqdm(train_dataloader, total=len(train_dataloader)):
    images, masks = images.to(device), masks.to(device)

    # One generator forward pass per step, shared by both updates.
    fake_images = gen(masks)

    # --- discriminator update ---
    set_requires_grad(disc, requires_grad=True)
    disc_opt.zero_grad()
    real_disc_out = disc(torch.cat([images, masks], 1))
    # detach(): the discriminator loss must not backpropagate into the generator.
    fake_disc_out = disc(torch.cat([fake_images.detach(), masks], 1))
    # ones_like / zeros_like already allocate on the outputs' device (no .cuda()).
    real_disc_loss = bce_loss(real_disc_out, torch.ones_like(real_disc_out))
    fake_disc_loss = bce_loss(fake_disc_out, torch.zeros_like(fake_disc_out))
    total_disc_loss = (real_disc_loss + fake_disc_loss) * 0.5
    total_disc_loss.backward()
    disc_opt.step()

    # --- generator update (discriminator frozen) ---
    set_requires_grad(disc, requires_grad=False)
    gen_opt.zero_grad()
    gen_l1_loss = l1_loss(fake_images, images)
    disc_out = disc(torch.cat([fake_images, masks], 1))
    gen_adv_loss = bce_loss(disc_out, torch.ones_like(disc_out))
    total_gen_loss = lambda_l1 * gen_l1_loss + gen_adv_loss
    total_gen_loss.backward()
    gen_opt.step()

    # Accumulate on-device to avoid a host sync every step.
    disc_loss_sum += total_disc_loss.detach()
    gen_loss_sum += total_gen_loss.detach()

  # --- per-epoch previews on the fixed preview samples ---
  gen.eval()
  with torch.no_grad():
    for kk, (_, test_masks) in enumerate(test_dataloader):
      result = gen(test_masks.to(device))
      # The preview loader uses batch_size=1; the first sample is the only one.
      output_numpy = denormalize(result[0]).clamp(0, 1).cpu().numpy()
      output_image = Image.fromarray((output_numpy * 255).astype(np.uint8).transpose(1, 2, 0))
      output_image.save(os.path.join(results_dir, f'{epoch}_{kk}_fake.jpg'))

  # Checkpoint after every completed epoch so a disconnect loses at most one epoch.
  torch.save({
      'epoch': epoch,
      'gen_state_dict': gen.state_dict(),
      'genopt_state_dict': gen_opt.state_dict(),
      'disc_state_dict': disc.state_dict(),
      'discopt_state_dict': disc_opt.state_dict(),
  }, checkpoint_path)

  print(f'Epoch [{epoch + 1}], Discriminator Loss (epoch mean): {float(disc_loss_sum) / n_batches}, '
        f'Generator Loss (epoch mean): {float(gen_loss_sum) / n_batches}')

100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [50], Discriminator Loss: 0.05142281949520111, Generator Loss: 4.250527858734131


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [51], Discriminator Loss: 0.15634483098983765, Generator Loss: 3.9832491874694824


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [52], Discriminator Loss: 0.07578960806131363, Generator Loss: 3.6275548934936523


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [53], Discriminator Loss: 0.06491173803806305, Generator Loss: 4.302152156829834


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [54], Discriminator Loss: 0.18794973194599152, Generator Loss: 3.1014275550842285


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [55], Discriminator Loss: 0.13610419631004333, Generator Loss: 3.6736156940460205


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [56], Discriminator Loss: 0.10339955985546112, Generator Loss: 3.3466906547546387


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [57], Discriminator Loss: 0.06339671462774277, Generator Loss: 4.975586891174316


100%|██████████| 2250/2250 [20:39<00:00,  1.81it/s]


Epoch [58], Discriminator Loss: 0.0782412439584732, Generator Loss: 4.855637550354004


100%|██████████| 2250/2250 [20:39<00:00,  1.81it/s]


Epoch [59], Discriminator Loss: 0.09039057046175003, Generator Loss: 4.125000476837158


100%|██████████| 2250/2250 [20:39<00:00,  1.81it/s]


Epoch [60], Discriminator Loss: 0.16032296419143677, Generator Loss: 3.3431930541992188


100%|██████████| 2250/2250 [20:39<00:00,  1.81it/s]


Epoch [61], Discriminator Loss: 0.1505441963672638, Generator Loss: 3.383056163787842


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [62], Discriminator Loss: 0.15682171285152435, Generator Loss: 3.295691728591919


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [63], Discriminator Loss: 0.10158611834049225, Generator Loss: 3.985898494720459


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [64], Discriminator Loss: 0.07510465383529663, Generator Loss: 4.993413925170898


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [65], Discriminator Loss: 0.21557122468948364, Generator Loss: 3.197904109954834


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [66], Discriminator Loss: 0.11566153168678284, Generator Loss: 3.0970935821533203


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [67], Discriminator Loss: 0.1746547818183899, Generator Loss: 3.397085428237915


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [68], Discriminator Loss: 0.14645393192768097, Generator Loss: 3.4525327682495117


100%|██████████| 2250/2250 [20:39<00:00,  1.82it/s]


Epoch [69], Discriminator Loss: 0.06067577004432678, Generator Loss: 4.547272682189941


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [70], Discriminator Loss: 0.10521315038204193, Generator Loss: 3.7592687606811523


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [71], Discriminator Loss: 0.15120413899421692, Generator Loss: 3.2021265029907227


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [72], Discriminator Loss: 0.06473126262426376, Generator Loss: 3.6148080825805664


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [73], Discriminator Loss: 0.11658594012260437, Generator Loss: 5.493468761444092


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [74], Discriminator Loss: 0.04914376884698868, Generator Loss: 5.448707580566406


100%|██████████| 2250/2250 [20:43<00:00,  1.81it/s]


Epoch [75], Discriminator Loss: 0.10962842404842377, Generator Loss: 3.81827974319458


100%|██████████| 2250/2250 [20:43<00:00,  1.81it/s]


Epoch [76], Discriminator Loss: 0.14727994799613953, Generator Loss: 3.243567943572998


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [77], Discriminator Loss: 0.2049919068813324, Generator Loss: 4.025151252746582


100%|██████████| 2250/2250 [20:40<00:00,  1.81it/s]


Epoch [78], Discriminator Loss: 0.1320689618587494, Generator Loss: 3.033667802810669


100%|██████████| 2250/2250 [20:43<00:00,  1.81it/s]


Epoch [79], Discriminator Loss: 0.16855354607105255, Generator Loss: 2.9431166648864746


100%|██████████| 2250/2250 [20:43<00:00,  1.81it/s]


Epoch [80], Discriminator Loss: 0.12965603172779083, Generator Loss: 3.991396427154541


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [81], Discriminator Loss: 0.06179840862751007, Generator Loss: 4.395103454589844


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [82], Discriminator Loss: 0.16433444619178772, Generator Loss: 3.3071112632751465


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [83], Discriminator Loss: 0.12375441938638687, Generator Loss: 3.8614988327026367


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [84], Discriminator Loss: 0.11123117804527283, Generator Loss: 3.6917550563812256


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [85], Discriminator Loss: 0.13066703081130981, Generator Loss: 4.206608772277832


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [86], Discriminator Loss: 0.16848033666610718, Generator Loss: 3.0619077682495117


100%|██████████| 2250/2250 [20:41<00:00,  1.81it/s]


Epoch [87], Discriminator Loss: 0.10278636962175369, Generator Loss: 4.1566243171691895


100%|██████████| 2250/2250 [20:31<00:00,  1.83it/s]


Epoch [88], Discriminator Loss: 0.27841177582740784, Generator Loss: 2.886536121368408


100%|██████████| 2250/2250 [20:42<00:00,  1.81it/s]


Epoch [89], Discriminator Loss: 0.096564881503582, Generator Loss: 4.090121269226074


 61%|██████▏   | 1383/2250 [12:42<07:59,  1.81it/s]

In [None]:
# Weights-only snapshot of both networks to Drive.
# NOTE(review): the "_100" suffix does not reflect the epoch actually reached;
# the training run above was interrupted partway through an epoch.
torch.save(gen.state_dict(), f'/content/drive/MyDrive/gen_512_100.pth')
torch.save(disc.state_dict(), f'/content/drive/MyDrive/disc_512_100.pth')

In [None]:
# Full training state (both networks + both optimizers) in one file, tagged
# with the training loop's last `epoch` value.
EPOCH = epoch
PATH = "model.pt"

torch.save(
    {
        'epoch': EPOCH,
        'gen_state_dict': gen.state_dict(),
        'genopt_state_dict': gen_opt.state_dict(),
        'disc_state_dict': disc.state_dict(),
        'discopt_state_dict': disc_opt.state_dict(),
    },
    PATH,
)

In [None]:
# Inference set. SegmentationDataset takes transform_image / transform_mask; it
# has no `transform` keyword, and `transform` is not defined in this notebook.
# NOTE(review): expects Test/masked_images and Test/annotations to exist.
test_dataset = SegmentationDataset('Test', transform_image=transform_image, transform_mask=transform_mask)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# model = Custom(n_class=2)
# model.load_state_dict(torch.load('point.pth'))

In [None]:
# Run the generator over test_dataloader and write each generated image to
# /content/UTFPR-SBD3/<sample index>_fake.jpg.
denormalize = transforms.Normalize((-1, -1, -1), (2, 2, 2))  # (x + 1) / 2: [-1, 1] -> [0, 1]
gen.eval()
with torch.no_grad():
    for kk, (_, masks) in enumerate(test_dataloader):
        # Only the annotations feed the generator; the real images stay on the CPU.
        result = gen(masks.to(device))
        # Save every sample in the batch; with batch_size=1 the index equals kk.
        for b, sample in enumerate(result):
            output_numpy = denormalize(sample).clamp(0, 1).cpu().numpy()
            output_image = Image.fromarray((output_numpy * 255).astype(np.uint8).transpose(1, 2, 0))
            output_image.save(f'/content/UTFPR-SBD3/{kk * test_dataloader.batch_size + b}_fake.jpg')


In [None]:
# Drive folder for weight files named by epoch.
os.makedirs('/content/drive/MyDrive/clothes_unet_',exist_ok=True)

In [None]:
# Save both networks, tagged with the `epoch` value left over from the training loop.
torch.save(gen.state_dict(), f'/content/drive/MyDrive/clothes_unet_/gen_{epoch}.pth')
torch.save(disc.state_dict(), f'/content/drive/MyDrive/clothes_unet_/disc_{epoch}.pth')

In [None]:
# Restore the epoch-99 weights. map_location lets GPU-saved tensors load on the
# current device (required on CPU-only runtimes).
disc.load_state_dict(torch.load('/content/drive/MyDrive/clothes_unet_/disc_99.pth', map_location=device))
gen.load_state_dict(torch.load('/content/drive/MyDrive/clothes_unet_/gen_99.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Ensure both networks live on the active device after loading.
gen, disc = gen.to(device), disc.to(device)

In [None]:

# gen_drive_link:  https://drive.google.com/file/d/1-1zNpaTmJLTc3sGeqzx-HKqc8s_9QiQu/view?usp=share_link
# disc_drive_link: https://drive.google.com/file/d/1-2KyHlw-q3x84UWG8SlwZs07ZEzRaHc8/view?usp=share_link

gen_drive_link :