In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import torch
import numpy as np
import random
from datetime import datetime

import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
import torchvision.transforms.functional as FT
from torchvision import utils

from torch.utils.data import Dataset, DataLoader

import PIL.Image as Image
import matplotlib.pyplot as plt

import subprocess
# from generate_shapes_and_images import generate, generateImage
from model import Generator

# from myService.myModel import *
# from myService.myDataset import MyDataset
from myService.myUtils import my_collate
from myService.getImages import GetImages
from options import BaseOptions

from tqdm import tqdm
from PIL import Image
from torchvision.utils import save_image
from swinModels.swin_transformer import SwinTransformer

random.seed(datetime.now().timestamp())

Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.


In [2]:
epochs= 100
batch_size= 64

truncation_ratio = 0.5

In [3]:
inference_identities = 1

opt = BaseOptions().parse()
opt.camera.uniform = True
opt.model.is_test = True
opt.model.freeze_renderer = False
opt.rendering.offset_sampling = True
opt.rendering.static_viewdirs = True
opt.rendering.force_background = True
opt.rendering.perturb = 0
opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim = 64
opt.inference.style_dim = opt.model.style_dim
opt.inference.project_noise = opt.model.project_noise

# User options
model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model
opt.inference.no_surface_renderings = True # When true, only RGB images will be created
opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated
opt.inference.identities = inference_identities # Number of identities to generate
opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if self.opt.inference.fixed_camera_angles is true.

opt.model.size = 1024
opt.experiment.expname = 'ffhq1024x1024'

usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]
                             [--config CONFIG] [--expname EXPNAME]
                             [--ckpt CKPT] [--continue_training]
                             [--checkpoints_dir CHECKPOINTS_DIR] [--iter ITER]
                             [--batch BATCH] [--chunk CHUNK]
                             [--val_n_sample VAL_N_SAMPLE]
                             [--d_reg_every D_REG_EVERY]
                             [--g_reg_every G_REG_EVERY]
                             [--local_rank LOCAL_RANK] [--mixing MIXING]
                             [--lr LR] [--r1 R1] [--view_lambda VIEW_LAMBDA]
                             [--eikonal_lambda EIKONAL_LAMBDA]
                             [--min_surf_lambda MIN_SURF_LAMBDA]
                             [--min_surf_beta MIN_SURF_BETA]
                             [--path_regularize PATH_REGULARIZE]
                             [--path_batch_shrink PATH_BATCH_SHRINK] [--wandb]
        

In [10]:
use_cuda = 1
device = torch.device("cuda" if (torch.cuda.is_available() & use_cuda) else "cpu")
print(device)

cuda


In [11]:
class MyDataset(Dataset):
    '''
    load the dataset
    '''
    def __init__(self, startIdx, endIdx, transform = None):
        self.startIdx = startIdx
        self.endIdx = endIdx
        
        camera_json_path = './prepareDataset/json/camera_paras.json'
        # z_json_path = './prepareDataset/json/sample_z_actual_used.json'
        z_json_path = './prepareDataset/json/sample_z.json'
        
        default_transform = transforms.Compose([
            # transforms.Resize((28,28)),
            transforms.ToTensor()
            ])

        with open(camera_json_path) as jsonFile:
            camera_paras = json.load(jsonFile)

        with open(z_json_path) as jsonFile:
            sample_z = json.load(jsonFile)
        
        self.camera_paras = camera_paras[startIdx: endIdx]
        self.sample_z = sample_z[startIdx: endIdx]
        # print(f'self.sample_z.shape: {np.array(self.sample_z).shape}')

        if transform == None:
            self.transform = default_transform
        else:
            self.transform = transform
        print('number of total data:{}'.format(len(self.camera_paras)))

    def __len__(self):
        return len(self.camera_paras)

    def __getitem__(self, idx):
        '''
        :param idx: Index of the image file
        :return: returns the image and corresponding label file.
        '''
        # read image with PIL module
        image_name = './prepareDataset/thumbs/' + str(self.startIdx + idx).rjust(7, "0") + ".png"
        image = Image.open(image_name, mode='r')
        image = image.convert('RGB')
        image = self.transform(image)

        # len: 256
        sample_z = self.sample_z[idx][0] 
        # print(np.array(sample_z).shape)

        # len: 12
        # camera_paras = np.array(self.camera_paras[self.startIdx + idx]["sample_cam_extrinsics"][0]).flatten().tolist() + self.camera_paras[self.startIdx + idx]["sample_locations"][0]
        camera_paras = np.array(self.camera_paras[idx]["sample_cam_extrinsics"][0]).flatten().tolist()

        # len: 256 + 12 = 268
        target = sample_z + camera_paras
        target = torch.tensor(target)

        return (image, target)

In [12]:
train_test_split_point = 8000
trainData = MyDataset(startIdx=0, endIdx=train_test_split_point, transform=None)
trainData_loader = DataLoader(trainData, batch_size=batch_size, num_workers=0,  collate_fn = my_collate, shuffle=True)

testData = MyDataset(startIdx=train_test_split_point, endIdx=10001, transform=None)
testData_loader = DataLoader(testData, batch_size=batch_size, num_workers=0,  collate_fn = my_collate, shuffle=False)

number of total data:8000
number of total data:2000


In [13]:
for data in trainData_loader:
    print(data[0].shape)
    print(data[1].shape)
    break

torch.Size([64, 3, 64, 64])
torch.Size([64, 268])


In [14]:
class ConvM(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        norm_layer = nn.BatchNorm2d
        super(ConvM, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU(inplace=True),
        )
class ConvNet(nn.Module):
    def __init__(self, n_class=16):
        super(ConvNet, self).__init__()
        
        self.conv = nn.Sequential(
            ConvM(3, 32, 5, 2),
            ConvM(32, 32, 5, 2),
            ConvM(32, 32, 3, 1),
            ConvM(32, 32, 3, 1),
        )        
        self.fc1 = nn.Linear(32, 256)
        self.fc2 = nn.Linear(256, 1000)
        self.fc3 = nn.Linear(1000, n_class)
    def forward(self, x):
        x = self.conv(x)
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
class ResidualBlock(nn.Module):
    """Residual Block with instance normalization."""
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))

    def forward(self, x):
        return x + self.main(x)
class ResnetEncoder(nn.Module):
    # 202005251539 attr dim
    def __init__(self, input_nc=3, output_nc=3, n_blocks=3): 
        assert(n_blocks >= 0)
        super(ResnetEncoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        ngf = 64
        padding_type ='reflect'
        norm_layer = nn.InstanceNorm2d
        use_bias = False
        
        model = [nn.Conv2d(input_nc , ngf, kernel_size=7, padding=3,
                           bias=use_bias),
                 norm_layer(ngf, affine=True, track_running_stats=True),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=4,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2, affine=True, track_running_stats=True),
                      nn.ReLU(True)]
        mult = 2**n_downsampling
        
        for i in range(n_blocks):
            model += [ResidualBlock(dim_in=ngf * mult, dim_out=ngf * mult)]
        
        self.model = nn.Sequential(*model)
        # 65536
        self.fc1 = nn.Linear(256 * 16 * 16, 268)

    def forward(self, x):
        x = self.model(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

In [15]:
# build model
convNet = ConvNet(n_class=268).to(device)
resnetEncoder = ResnetEncoder().to(device)
swin_transformer = SwinTransformer(img_size=64, window_size=4, num_classes=268).to(device)

encoder = resnetEncoder

In [16]:
# # load model
# epoch = 10
# resnetEncoder_path = f"./result/models/{epoch}.ckpt"
# resnetEncoder.load_state_dict(torch.load(resnetEncoder_path, map_location=lambda storage, loc: storage))

In [17]:
# loss = torch.nn.CrossEntropyLoss().to(device)
lossModel_image = torch.nn.MSELoss().to(device)
# lossModel_image = torch.nn.CrossEntropyLoss().to(device)
# lossModel_image = torch.nn.L1Loss().to(device)
lossModel_latent = torch.nn.MSELoss().to(device)
# lossModel_latent = torch.nn.CrossEntropyLoss().to(device)
# lossModel_latent = torch.nn.L1Loss().to(device)
# loss = torch.nn.SmoothL1Loss().to(device)
# optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
optimizer = optim.Adam(encoder.parameters(), lr=0.001, betas=[0.5, 0.999])
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20,30,40,50,60,70,80,90], gamma=0.7)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.3)

In [18]:
# get g_ema
model_path = 'ffhq1024x1024.pt'
checkpoint_path = os.path.join('full_models', model_path)
checkpoint = torch.load(checkpoint_path)

g_ema = Generator(model_opt=opt.model, renderer_opt=opt.rendering, full_pipeline=True).to(device)

pretrained_weights_dict = checkpoint["g_ema"]
# pretrained_weights_dict = checkpoint["g"]
model_dict = g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
    if v.size() == model_dict[k].size():
        model_dict[k] = v

g_ema.load_state_dict(model_dict)
g_ema.eval()

Generator(
  (style): Sequential(
    (0): MappingLinear(256, 256)
    (1): MappingLinear(256, 256)
    (2): MappingLinear(256, 256)
  )
  (renderer): VolumeFeatureRenderer(
    (network): SirenGenerator(
      (pts_linears): ModuleList(
        (0): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (1): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (2): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (3): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (4): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (5): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (6): FiLMSiren(
          (gamma): LinearLayer()
          (beta): LinearLayer()
        )
        (7): FiLMSiren(
          (gamma): LinearLayer()
          (b

In [19]:
with open("./prepareDataset/json/mean_latent.json", 'r') as jsonFile:
    mean_latent = json.load(jsonFile)
for i in range(len(mean_latent)):
    mean_latent[i] = torch.Tensor(mean_latent[i]).to(device=device)

In [20]:
def generateImage(latent, g_ema, seeIdx, fileName, save=True):
    # all are tensor
    latent = latent[seeIdx]
    
    sample_z = latent[0:256].reshape(1,256)
    sample_cam_extrinsics = torch.Tensor([[
        latent[256:260].tolist(),
        latent[260:264].tolist(),
        latent[264:268].tolist(),
    ]]).to(device)
    sample_focals = torch.Tensor([[
        [
            304.45965576171875
        ]
    ]]).to(device)
    sample_near = torch.Tensor([[
        [
            0.8799999952316284
        ]
    ]]).to(device)
    sample_far = torch.Tensor([[
        [
            1.1200000047683716
        ]
    ]]).to(device)
    out = g_ema([sample_z],
            sample_cam_extrinsics,
            sample_focals,
            sample_near,
            sample_far,
            truncation=truncation_ratio,
            truncation_latent=mean_latent)
    opt.renderer_output_size = 64
    rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
    rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0)

    rgb_images_thumbs = rgb_images_thumbs.reshape(3,64,64)
    # print(f'generate: {rgb_images_thumbs}')
    if save:
        utils.save_image(rgb_images_thumbs,
            # os.path.join(prepareDatasetPath, 'thumbs',f'{str(i).zfill(7)}.png'),
            f'./result/images/{fileName}.png',
            nrow=1,
            normalize=True,
            padding=0,
            value_range=(-1, 1),)
    # save_image([rgb_images_thumbs], f'./result/images/{fileName}.png')
    
    # image = Image.fromarray(rgb_images_thumbs.permute(1, 2, 0))
    # image.save(f'{fileName}.png', format='PNG')
    # %matplotlib inline
    # plt.imshow(  rgb_images_thumbs.permute(1, 2, 0)  )
    # plt.imshow(  rgb_images_thumbs  )
    return rgb_images_thumbs

def generateImageBatch(latent, para, g_ema):
    sample_cam_extrinsics = para[:, :12].reshape(-1,3,4)
    sample_focals = torch.Tensor([[
        [
            304.45965576171875
        ]
    ]]*batch_size).to(device)
    sample_near = torch.Tensor([[
        [
            0.8799999952316284
        ]
    ]]*batch_size).to(device)
    sample_far = torch.Tensor([[
        [
            1.1200000047683716
        ]
    ]]*batch_size).to(device)
    
    chunk = 2
    
    thumb_rgb = torch.Tensor(0, 3, 64, 64).to(device)
    thumb_rgb.requires_grad = True

    for j in range(0, batch_size, chunk):
        out = g_ema([latent[j:j+chunk]],
                    sample_cam_extrinsics[j:j+chunk],
                    sample_focals[j:j+chunk],
                    sample_near[j:j+chunk],
                    sample_far[j:j+chunk],
                    truncation=truncation_ratio,
                    truncation_latent=mean_latent)

        # rgb_images = torch.cat([rgb_images, out[0].cpu()], 0)
        thumb_rgb = torch.cat([thumb_rgb, out[1]], 0)

        del out
        torch.cuda.empty_cache()
    
    return thumb_rgb
    
    # all are tensor
    # rgb_images_thumbs_list = []
    # for i in range(len(latent)):
    #     rgb_images_thumbs = generateImage(latent=latent, g_ema=g_ema, seeIdx=i, fileName=None, save=False)
    #     rgb_images_thumbs_list.append(rgb_images_thumbs.tolist())
    
    # rgb_images_thumbs_list = torch.Tensor(rgb_images_thumbs_list, requires_grad=True).to(device=device)
    # return rgb_images_thumbs_list
    pass

def evalmodel(model, testloader, lossModel_latent, epoch):
    prob = 0.1
    model.eval()
    test_loss_latent = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(testloader):
            data, target = data.to(device), target.to(device)
            output_latent = model(data)
            # loss_cnn = lossModel_latent(output_latent,target)
            loss_latent = lossModel_latent(output_latent[:, :256],target[:, :256])
            test_loss_latent += loss_latent

            if random.random() < prob:
                seeIdx = random.randint(0, batch_size-1)
                # para
                # concate = torch.concat([target[:, :256], output_latent[:, 256:]], dim=1)
                # sample_z
                concate = torch.concat([output_latent[:, :256], target[:, 256:]], dim=1)
                
                # normal
                # generateImage(output_latent, g_ema, seeIdx, f'{epoch}_{batch_idx}_output')
                # para / sample_z
                # print(f'concate.shape: {concate.shape}')
                generateImage(concate, g_ema, seeIdx, f'{epoch}_{batch_idx * batch_size + seeIdx +train_test_split_point}_output')
                # generateImage(target, g_ema, seeIdx, f'{epoch}_{batch_idx * batch_size + seeIdx + train_test_split_point}_generateTarget')
                save_image([data[seeIdx]], f'./result/images/{epoch}_{batch_idx * batch_size + seeIdx + train_test_split_point}_realTarget.png')
                # print(f'real: {data[seeIdx]}')
                

    test_loss_latent /= len(testloader.dataset)
    # print(f'len(testloader.dataset): {len(testloader.dataset)}')
    return test_loss_latent

def train(model, optimizer, dataloader_train, testloader, lossModel_latent, lossModel_image, total_epoch, scheduler):
    useImageLoss = False
    
    # 步驟5. CNN模型開始訓練
    loss_train_list=[]
    loss_test_list=[]

    for epoch in range(total_epoch):
        scheduler.step()
        
        # train
        model.train()
        train_loss_latent = 0
        train_loss_image = 0
        # for batch_idx, (data, target) in enumerate(dataloader_train):
        for (data, target) in tqdm(dataloader_train):
        # for (data, target) in dataloader_train:
            data, target = data.to(device), target.to(device)
            # optimizer.zero_grad()
            output_cnn = model(data)
            
            # normal
            # loss_cnn = loss_latent(output_cnn,target) * loss_scaling
            # paras
            # loss_cnn = loss_latent(output_cnn[:, 256:],target[:, 256:])# * loss_scaling # train sample_z
            # sample_z
            loss_latent = lossModel_latent(output_cnn[:, :256],target[:, :256])# * loss_scaling # train sample_z
            # loss_latent2 = lossModel_latent(output_cnn,target)# * loss_scaling # train sample_z
            # print(loss_latent)
            # print(loss_latent2)
            # loss_latent = lossModel_latent(output_cnn,target)# * loss_scaling # train sample_z
            # concate = torch.concat([output_cnn[:, :256], target[:, 256:]], dim=1)

            # output_image = generateImageBatch(latent=output_cnn[:, :256], para=target[:, 256:], g_ema=g_ema)
            # loss_image = lossModel_image(output_image, data)
            # total_loss = loss_image + loss_latent
            # train_loss_latent += loss_latent
            # train_loss_image += loss_image

            # total_loss = loss_latent

            # print(loss_cnn)
            # print(loss_image)
            

            # loss_latent.backward()
            # loss_image.backward()
            # total_loss.backward()
            optimizer.zero_grad()
            loss_latent.backward()
            # print(loss_latent.item())
            optimizer.step()
            train_loss_latent += loss_latent.item()
            # print(loss_latent.item())
            # print(train_loss_latent)

        # print(train_loss_latent)
        train_loss_latent /= len(dataloader_train.dataset)
        # print(f'len(dataloader_train.dataset): {len(dataloader_train.dataset)}')
        if useImageLoss:
            train_loss_image /= len(dataloader_train.dataset)
        
    
        if epoch % 10 == 0:
            test_loss = evalmodel(model, testloader, lossModel_latent, epoch)
            
            # seeIdx = 0
            # # generateImage(output_cnn, g_ema, seeIdx, f'{epoch}_output')
            # concate = torch.concat([target[:, :256], output_cnn[:, 256:]], dim=1)
            # print(f'concate.shape: {concate.shape}')
            # generateImage(concate, g_ema, seeIdx, f'{epoch}_output')
            # generateImage(target, g_ema, seeIdx, f'{epoch}_target')
            
            # loss_train_list.append(total_loss)
            loss_test_list.append(test_loss)
            print('learning rate:{}'.format(scheduler.get_last_lr()[0]))
            print(F'CNN[epoch: [{epoch+1}/{total_epoch}], Average loss latent/image (Train):{train_loss_latent}/{train_loss_image},  Average loss latent (test):{test_loss}')

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f'./result/models/{epoch}.ckpt')

    print(F'CNN[epoch: [{epoch+1}/{total_epoch}], Average loss latent/image (Train):{train_loss_latent}/{train_loss_image},  Average loss latent (test):{test_loss}')
    print('training done.')
    
    return loss_train_list, loss_test_list

In [21]:
print('*'*50)
print('Training ... ')
print("resNetEncoder latent only")
loss_train_list, loss_test_list = train(encoder, optimizer, trainData_loader, testData_loader, lossModel_latent, lossModel_image, total_epoch=epochs, scheduler=scheduler)


**************************************************
Training ... 
resNetEncoder latent only


100%|██████████| 125/125 [00:14<00:00,  8.78it/s]


concate.shape: torch.Size([64, 268])
concate.shape: torch.Size([64, 268])
concate.shape: torch.Size([64, 268])
concate.shape: torch.Size([64, 268])
learning rate:0.001
CNN[epoch: [1/100], Average loss latent/image (Train):0.4979997551739216/0,  Average loss latent (test):0.05009942129254341


100%|██████████| 125/125 [00:12<00:00, 10.38it/s]
100%|██████████| 125/125 [00:11<00:00, 10.54it/s]
100%|██████████| 125/125 [00:11<00:00, 10.52it/s]
100%|██████████| 125/125 [00:11<00:00, 10.50it/s]
100%|██████████| 125/125 [00:11<00:00, 10.47it/s]
100%|██████████| 125/125 [00:12<00:00, 10.41it/s]
100%|██████████| 125/125 [00:12<00:00, 10.39it/s]
100%|██████████| 125/125 [00:12<00:00, 10.31it/s]
100%|██████████| 125/125 [00:11<00:00, 10.50it/s]
100%|██████████| 125/125 [00:11<00:00, 10.50it/s]


concate.shape: torch.Size([64, 268])
concate.shape: torch.Size([64, 268])
concate.shape: torch.Size([16, 268])
learning rate:0.001
CNN[epoch: [11/100], Average loss latent/image (Train):0.015947767414152623/0,  Average loss latent (test):0.02273315191268921


100%|██████████| 125/125 [00:11<00:00, 10.55it/s]
100%|██████████| 125/125 [00:11<00:00, 10.50it/s]
100%|██████████| 125/125 [00:11<00:00, 10.46it/s]
100%|██████████| 125/125 [00:12<00:00, 10.13it/s]
100%|██████████| 125/125 [00:12<00:00, 10.18it/s]
100%|██████████| 125/125 [00:12<00:00, 10.24it/s]
100%|██████████| 125/125 [00:12<00:00, 10.32it/s]
100%|██████████| 125/125 [00:11<00:00, 10.45it/s]
100%|██████████| 125/125 [00:12<00:00, 10.36it/s]
100%|██████████| 125/125 [00:12<00:00, 10.33it/s]


concate.shape: torch.Size([64, 268])
learning rate:0.001
CNN[epoch: [21/100], Average loss latent/image (Train):0.0086956601254642/0,  Average loss latent (test):0.02189939096570015


100%|██████████| 125/125 [00:12<00:00, 10.22it/s]
100%|██████████| 125/125 [00:12<00:00, 10.35it/s]
100%|██████████| 125/125 [00:12<00:00, 10.29it/s]
100%|██████████| 125/125 [00:12<00:00, 10.28it/s]
100%|██████████| 125/125 [00:12<00:00, 10.29it/s]
100%|██████████| 125/125 [00:12<00:00, 10.34it/s]
100%|██████████| 125/125 [00:12<00:00, 10.28it/s]
100%|██████████| 125/125 [00:12<00:00, 10.30it/s]
 78%|███████▊  | 98/125 [00:09<00:02, 10.35it/s]

In [None]:
for i in range(len(loss_train_list)):
    loss_train_list[i] = loss_train_list[i].cpu().detach().numpy()

for i in range(len(loss_test_list)):
    loss_test_list[i] = loss_test_list[i].cpu()

In [None]:
plt.plot(loss_train_list, color='red', label='train loss')
plt.plot(loss_test_list, color='blue', label='test loss')
plt.legend()
plt.savefig(f'./result/loss/lossBoth.png')
plt.cla()

plt.plot(loss_train_list, color='red', label='train loss')
plt.legend()
plt.savefig(f'./result/loss/lossTrain.png')
plt.cla()

plt.plot(loss_test_list, color='blue', label='test loss')
plt.legend()
plt.savefig(f'./result/loss/lossTest.png')