# Notebook para treinar a GAN para síntese de imagens CT

### Definições iniciais

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import trange
import random
random.seed(5)
import matplotlib.pyplot as plt
import numpy as np
import os
import csv

In [None]:
from datasets import lungCTData
from model import Generator, Discriminator
from main import run_train_epoch, run_validation_epoch
from utils import clean_directory

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Definições do treino

In [None]:
# --- Data: [start, end] index ranges passed to lungCTData for each split ---
start_point_train_data, end_point_train_data = 0, 30000
start_point_validation_data, end_point_validation_data = 30000, 40000

# --- Batch sizes ---
batch_size_train, batch_size_validation = 64, 8

# --- Learning-rate schedule ---
n_epochs = 200
initial_lr = 2e-4
# Epoch from which the training loop starts stepping the LR schedulers.
epoch_to_switch_to_lr_scheduler = 100

# --- Loss ---
criterion = torch.nn.BCELoss()
# Passed as `regularization` to the train/validation epoch runners.
regularization = 10
# Presumably the number of steps to run before each discriminator /
# generator update (consumed by run_train_epoch) - TODO confirm.
steps_to_complete_bfr_upd_disc = 1
steps_to_complete_bfr_upd_gen = 1

# --- Checkpointing: "safe save" the models every this many epochs ---
step_to_safe_save_models = 10

# --- Output locations ---
# When True, the results directory is cleaned (clean_directory) before training.
new_model = True
dir_save_results = './first_model/'
dir_save_models = f"{dir_save_results}models/"
dir_save_example = f"{dir_save_results}examples/"
name_model = 'my_first_model'

### Dados

In [None]:
# Folder holding the preprocessed CT data. It can be overridden with the
# PULMONET_PROCESSED_DATA environment variable, so the notebook runs on other
# machines. The default is the original author's local path.
processed_data_folder = os.environ.get(
    'PULMONET_PROCESSED_DATA',
    '/home/arthur/Documentos/generativas/dgm-2024.2/projetos/PulmoNet/data/processed',
)

In [None]:
# Both splits are read from the processed 'train' data and differ only in the
# index range they take.
# NOTE(review): validation also uses mode='train'; confirm this is intentional
# (i.e. validation is a held-out slice of the training data).
_split_kwargs = dict(processed_data_folder=processed_data_folder, mode='train')
dataset_train = lungCTData(start=start_point_train_data,
                           end=end_point_train_data,
                           **_split_kwargs)
dataset_validation = lungCTData(start=start_point_validation_data,
                                end=end_point_validation_data,
                                **_split_kwargs)

In [None]:
# Shuffled loaders for both splits, each with its own batch size.
data_loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size_train,
    shuffle=True,
)
data_loader_validation = DataLoader(
    dataset_validation,
    batch_size=batch_size_validation,
    shuffle=True,
)

### Modelos

In [None]:
# Build the generator and the discriminator with their default constructor
# arguments and move both to the selected device (generator first).
gen, disc = Generator().to(device), Discriminator().to(device)

### Optimizers

In [None]:
# Adam with beta1 = 0.5 for both networks, starting at the same learning rate.
adam_betas = (0.5, 0.999)
gen_opt = torch.optim.Adam(gen.parameters(), lr=initial_lr, betas=adam_betas)
disc_opt = torch.optim.Adam(disc.parameters(), lr=initial_lr, betas=adam_betas)

# Linear decay from initial_lr to 0. The training loop steps the schedulers
# once per epoch from epoch_to_switch_to_lr_scheduler onwards, so the decay
# must span exactly the remaining epochs. The previous fixed total_iters=50
# drove the LR to 0 halfway through that window, leaving the last epochs
# training at lr = 0. max(1, ...) keeps total_iters valid if there are no
# decay epochs.
lr_decay_iters = max(1, n_epochs - epoch_to_switch_to_lr_scheduler)
gen_scheduler = torch.optim.lr_scheduler.LinearLR(
    gen_opt, start_factor=1.0, end_factor=0.0, total_iters=lr_decay_iters)
disc_scheduler = torch.optim.lr_scheduler.LinearLR(
    disc_opt, start_factor=1.0, end_factor=0.0, total_iters=lr_decay_iters)

### Loop de treino

In [None]:

# Per-epoch loss history, appended to by the training loop.
mean_loss_train_gen_list = []
mean_loss_validation_gen_list = []
mean_loss_train_disc_list = []
mean_loss_validation_disc_list = []
# Index of the first epoch whose losses have not been written to disk yet.
save_count_idx = 0

os.makedirs(dir_save_results, exist_ok=True)
if new_model:
    # Fresh run: clean the results directory first (presumably removing
    # previous checkpoints/examples - see utils.clean_directory).
    clean_directory(dir_save_results)
os.makedirs(dir_save_models, exist_ok=True)
# The config defines `dir_save_example` (singular); the plural name raised NameError.
os.makedirs(dir_save_example, exist_ok=True)
if new_model:
    # Write the header to the same file the training loop appends loss rows
    # to. It was previously written to a separate 'losses.csv', leaving
    # losses_evolution.csv without a header. Column order matches those rows.
    with open(dir_save_results + 'losses_evolution.csv', 'w', newline='') as csvfile:
        fieldnames = ['LossGenTrain', 'LossDiscTrain', 'LossGenVal', 'LossDiscVal']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

In [None]:
for epoch in range(n_epochs):

    # One training pass; returns (presumably per-epoch mean) generator and
    # discriminator losses.
    loss_train_gen, loss_train_disc = run_train_epoch(gen=gen, disc=disc, criterion=criterion, regularization=regularization,
                                        data_loader=data_loader_train, disc_opt=disc_opt, gen_opt=gen_opt,
                                        epoch=epoch, steps_to_complete_bfr_upd_disc=steps_to_complete_bfr_upd_disc,
                                        steps_to_complete_bfr_upd_gen=steps_to_complete_bfr_upd_gen, device=device)

    mean_loss_train_gen_list.append(loss_train_gen)
    mean_loss_train_disc_list.append(loss_train_disc)

    loss_validation_gen, loss_validation_disc = run_validation_epoch(gen=gen, disc=disc, criterion=criterion, regularization=regularization,
                                                data_loader=data_loader_validation, epoch=epoch, device=device)

    mean_loss_validation_gen_list.append(loss_validation_gen)
    mean_loss_validation_disc_list.append(loss_validation_disc)

    # NOTE(review): valid_on_the_fly is not imported anywhere in this notebook,
    # so this line raises NameError. TODO: import it from the module that
    # defines it.
    valid_on_the_fly(gen=gen, disc=disc, data_loader=data_loader_validation, epoch=epoch, save_dir=dir_save_example)

    if epoch % step_to_safe_save_models == 0:
        # Periodic "safe save". Generator and discriminator need distinct file
        # names: both schedulers report the same LR, so a shared name made the
        # discriminator checkpoint overwrite the generator's.
        gen_lr = gen_scheduler.get_last_lr()[0]
        disc_lr = disc_scheduler.get_last_lr()[0]
        torch.save(gen.state_dict(), f"{dir_save_models}{name_model}_gen_last_lr_{gen_lr}_savesafe.pt")
        torch.save(disc.state_dict(), f"{dir_save_models}{name_model}_disc_last_lr_{disc_lr}_savesafe.pt")
        # Append only the epochs recorded since the previous flush.
        with open(dir_save_results + 'losses_evolution.csv', mode='a', newline='') as file:
            writer = csv.writer(file)
            for i in range(save_count_idx, epoch + 1):
                writer.writerow([mean_loss_train_gen_list[i], mean_loss_train_disc_list[i],
                                 mean_loss_validation_gen_list[i], mean_loss_validation_disc_list[i]])
        save_count_idx = epoch + 1

    if epoch >= epoch_to_switch_to_lr_scheduler:
        # Per-epoch LR decay for the remaining epochs (was `setp`, an AttributeError).
        gen_scheduler.step()
        disc_scheduler.step()
        print("Current learning rate: gen: ", gen_scheduler.get_last_lr()[0], " disc: ", disc_scheduler.get_last_lr()[0])


In [None]:
# Final checkpoints. Distinct names are needed because both calls previously
# wrote '<name>_trained.pt', so the discriminator overwrote the generator.
torch.save(gen.state_dict(), f"{dir_save_models}{name_model}_gen_trained.pt")
torch.save(disc.state_dict(), f"{dir_save_models}{name_model}_disc_trained.pt")

# Flush loss rows recorded since the last safe save. Count completed epochs
# from the list that is appended to last in each epoch, instead of using the
# loop variable `epoch`. That variable is undefined when n_epochs == 0 and can
# overrun the shorter lists if the loop was interrupted mid-epoch.
n_recorded_epochs = len(mean_loss_validation_disc_list)
if save_count_idx < n_recorded_epochs:
    with open(dir_save_results + 'losses_evolution.csv', mode='a', newline='') as file:
        writer = csv.writer(file)
        for i in range(save_count_idx, n_recorded_epochs):
            writer.writerow([mean_loss_train_gen_list[i], mean_loss_train_disc_list[i],
                             mean_loss_validation_gen_list[i], mean_loss_validation_disc_list[i]])
    # Mark these rows as written, so re-running this cell does not duplicate them.
    save_count_idx = n_recorded_epochs