In [None]:
%run network.ipynb
import torch.optim as optim
import random
import os
from PIL import Image
from torchvision import transforms, datasets
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

def _resume_if_saved(model, optimizer, name):
    """Restore ``model``/``optimizer`` from ``saved_model/<name>`` if that file exists.

    ``load_w`` raises ValueError when the checkpoint is missing, which made a
    first run (no checkpoints yet) impossible. Only the missing-file case is
    skipped here; a corrupt or mismatched checkpoint still raises.
    """
    if os.path.isfile(os.path.join("saved_model", name)):
        load_w(model, optimizer, name)
    else:
        print("No checkpoint found for {}; starting from scratch.".format(name))


def _sample_real(folder_path, target_shape, png_files, k, device):
    """Load ``k`` distinct, randomly chosen training images as one tensor on ``device``."""
    return load_image(folder_path, target_shape, random.sample(png_files, k)).to(device)


def train():
    """Train the GAN on the PNG files in ``training_images`` for 50000 iterations.

    Resumes from ``saved_model/disc.pkl`` and ``saved_model/gen.pkl`` when they
    exist. Every 50 iterations (starting at 0) it:
      * passes ``generated_images/<i>.jpg`` to ``gen.generate_image``
        (presumably writes a sample image there),
      * checkpoints both networks and their optimizers with ``save_w``,
      * prints the discriminator's outputs on a few real and fake images.

    ``GAN_Dis``, ``GAN_Gen`` and ``nn`` are presumably provided by the
    ``%run network.ipynb`` cell — TODO confirm.

    Raises:
        ValueError: if ``training_images`` contains no ``.png`` files.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    folder_path = "training_images"
    output_folder = "generated_images"
    png_files = [f for f in os.listdir(folder_path) if f.lower().endswith('.png')]
    if not png_files:
        raise ValueError("No .png files found in {!r}".format(folder_path))
    # Sample images are written here every 50 iterations; make sure it exists.
    os.makedirs(output_folder, exist_ok=True)
    target_shape = (128, 128)
    seed_size = 128
    # random.sample cannot draw more files than exist, so cap both draw sizes.
    batch_size = min(200, len(png_files))
    preview_size = min(2, len(png_files))

    # Initialize the neural networks
    disc = GAN_Dis().to(device)
    optim_d = optim.Adam(disc.parameters(), lr=0.0002)
    gen = GAN_Gen().to(device)
    optim_g = optim.Adam(gen.parameters(), lr=0.0002)
    _resume_if_saved(disc, optim_d, "disc.pkl")
    _resume_if_saved(gen, optim_g, "gen.pkl")

    # Train parameters
    iterations = 50000
    m = 0  # mean of the latent noise
    s = 1  # standard deviation of the latent noise
    crit = nn.BCELoss()
    for i in range(iterations):
        print("Iteration", i)
        if i % 50 == 0:
            gen.generate_image(os.path.join(output_folder, str(i) + ".jpg"))
            save_w(disc, optim_d, "disc.pkl")
            save_w(gen, optim_g, "gen.pkl")
            # Print progress. This is monitoring only, so build no autograd graph.
            with torch.no_grad():
                seed0 = torch.normal(m, s, (preview_size, seed_size)).to(device)
                fake_p = disc(gen(seed0))
                real_p = disc(_sample_real(folder_path, target_shape, png_files, preview_size, device))
            print("Real probability", real_p)
            print("Fake probability", fake_p)

        # Discriminator training: push D(real) towards 1 and D(G(z)) towards 0.
        seed = torch.normal(m, s, (batch_size, seed_size)).to(device)
        real = _sample_real(folder_path, target_shape, png_files, batch_size, device)
        gen_sample = gen(seed)
        real_out = disc(real)
        # detach(): the discriminator step must not backpropagate into G.
        fake_out = disc(gen_sample.detach())
        # *_like labels share the output's shape and device. CPU labels fail on CUDA.
        disc_loss = (crit(real_out, torch.ones_like(real_out))
                     + crit(fake_out, torch.zeros_like(fake_out)))
        optim_d.zero_grad()
        disc_loss.backward()
        optim_d.step()

        # Generator training uses the non-saturating loss: minimise BCE(D(G(z)), 1),
        # i.e. maximise log D(G(z)). The minimax form log(1 - D(G(z))) (the old
        # -BCE(D(G(z)), 0)) has vanishing gradients while D confidently rejects
        # fakes, which is exactly the early-training regime.
        gen_out = disc(gen_sample)
        gen_loss = crit(gen_out, torch.ones_like(gen_out))
        # Clearing G's grads here; D's grads picked up by this backward are
        # cleared by optim_d.zero_grad() before D's next update.
        optim_g.zero_grad()
        gen_loss.backward()
        optim_g.step()

        print()
        
def save_w(model, optimizer, name):
    """Write a checkpoint of ``model`` and ``optimizer`` to ``saved_model/<name>``.

    The file holds a dict with keys ``"model"`` and ``"optimizer"``, each
    mapping to the corresponding ``state_dict()``. The ``saved_model``
    directory, relative to the working directory, is created when missing.
    """
    checkpoint_dir = "saved_model"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, name))

def load_w(model, optimizer, name):
    """Restore ``model`` and ``optimizer`` in place from ``saved_model/<name>``.

    The checkpoint must be a dict whose ``"model"`` and ``"optimizer"``
    entries are state dicts. Without CUDA, tensors are mapped to the CPU so
    that checkpoints written on a GPU still load.

    Returns:
        True after both state dicts have been loaded.

    Raises:
        ValueError: if the checkpoint file does not exist.
    """
    checkpoint_dir = os.path.abspath(os.path.expanduser("saved_model"))
    checkpoint_path = os.path.join(checkpoint_dir, name)
    if not os.path.isfile(checkpoint_path):
        raise ValueError("Failed to load weights from {}! File does not exist!".format(checkpoint_path))

    # None keeps torch.load's default device placement when a GPU is present.
    # torch.load unpickles its input: only load checkpoints you trust.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(checkpoint_path, map_location)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    print("Successfully loaded weights from {}!".format(checkpoint_path))
    return True
        
def load_image(folder_path, target_shape, file_names):
    """Load the named images from ``folder_path`` as one batched tensor.

    Each image is resized to ``target_shape`` (``(height, width)``, as accepted
    by ``transforms.Resize``). It is then converted with ``transforms.ToTensor``
    to a channels-first float tensor in [0, 1]. The images are stacked along a
    new leading batch dimension in the order of ``file_names``.

    Args:
        folder_path: directory containing the images.
        target_shape: output spatial size ``(height, width)``.
        file_names: file names relative to ``folder_path``.

    Returns:
        A CPU tensor of shape ``(len(file_names), C, height, width)``.

    Note:
        C follows each file's PIL mode (e.g. 3 for RGB, 4 for RGBA). Files whose
        modes differ cannot be stacked together.
        TODO(review): confirm whether the discriminator expects RGB; if so,
        convert with ``image.convert("RGB")`` here.
    """
    transform = transforms.Compose([
        # Resize so files of any size can be batched at target_shape.
        # Same-size images are returned unchanged.
        transforms.Resize(target_shape),
        transforms.ToTensor(),
    ])
    tensors = []
    for f in file_names:
        # Image.open keeps the file open lazily; the context manager closes it.
        with Image.open(os.path.join(folder_path, f)) as image:
            tensors.append(transform(image))
    return torch.stack(tensors, 0)
        
# Guard the 50000-iteration training run so that importing this file for its
# helpers does not start training. Notebook kernels and IPython's %run execute
# with __name__ == "__main__", so running the cell still trains.
if __name__ == "__main__":
    train()