In [5]:
%run network.ipynb
import torch.optim as optim
import random
import os
from PIL import Image
from torchvision import transforms, datasets
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    folder_path = "training_images"
    output_folder = "generated_images"
    png_files = [f for f in os.listdir(folder_path) if f.lower().endswith('.png')]
    target_shape = (128, 128)
    seed_size = 128
    
    # Initialize the neuro networks
    disc = GAN_Dis().to(device)
    batch_size = 200
    optim_d = optim.Adam(disc.parameters(), lr=0.0002)
    gen = GAN_Gen().to(device)
    optim_g = optim.Adam(gen.parameters(), lr=0.0002)
    load_w(disc, optim_d, "disc.pkl")
    load_w(gen, optim_g, "gen.pkl")
    
    # Train parameters
    iterations = 50000
    m = 0
    s = 1
    normal_shape = (batch_size, seed_size)
    crit = nn.BCELoss()
    for i in range(iterations):
        print("Iteration", i)
        train_disc = True
        train_gen = True
        if i % 50 == 0:
            file_name = str(i) + ".jpg"
            save_path = os.path.join(output_folder, file_name)
            gen.generate_image(save_path)
            save_w(disc, optim_d, "disc.pkl")
            save_w(gen, optim_g, "gen.pkl")
            #################################################
            seed0 = torch.normal(m, s, (2, 128)).to(device)
            selected_files = random.sample(png_files, 2)
            real0 = load_image(folder_path, target_shape, selected_files)
            gen_sample0 = gen(seed0)
            fake_p = disc(gen_sample0)
            real_p = disc(real0)
            print("Real probability", real_p)
            print("Fake probability", fake_p)
        
        #################################################  
        seed = torch.normal(m, s, normal_shape).to(device)
        selected_files = random.sample(png_files, batch_size)
        real = load_image(folder_path, target_shape, selected_files)
        gen_sample = gen(seed)
        disc_loss = crit(disc(real), torch.ones(batch_size)) + crit(disc(gen_sample.detach()), torch.zeros(batch_size))
        disc_loss.backward()
        optim_d.step()
        optim_g.zero_grad()
        optim_d.zero_grad()
        
        ###################################################
        gen_loss = - crit(disc(gen_sample), torch.zeros(batch_size))
        gen_loss.backward()
        optim_g.step()
        optim_g.zero_grad()
        optim_d.zero_grad()

        print()
        
def save_w(model, optimizer, name):
    os.makedirs("saved_model", exist_ok=True)
    save_path = os.path.join("saved_model", name)
    torch.save(dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict()
    ), save_path)

def load_w(model, optimizer, name):
    log_dir = os.path.abspath(os.path.expanduser("saved_model"))
    save_path = os.path.join(log_dir, name)
    if os.path.isfile(save_path):
        state_dict = torch.load(
            save_path,
            torch.device('cpu') if not torch.cuda.is_available() else None
        )
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optimizer"])
        print("Successfully loaded weights from {}!".format(save_path))
        return True
    else:
        raise ValueError("Failed to load weights from {}! File does not exist!".format(save_path))
        
def load_image(folder_path, target_shape, file_names):
    train_data = []
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    for f in file_names:
        p = os.path.join(folder_path, f)
        image = Image.open(p)
        tensor_image = transform(image)
        train_data.append(torch.unsqueeze(tensor_image, 0))
    cat_train = torch.cat(train_data, 0)
    return cat_train
        
train()

Successfully loaded weights from C:\Users\shiyu\OneDrive\Desktop\CAT-GAN\saved_model\disc.pkl!
Successfully loaded weights from C:\Users\shiyu\OneDrive\Desktop\CAT-GAN\saved_model\gen.pkl!
Iteration 0
Real probability tensor([0.6509, 0.7494], grad_fn=<ViewBackward0>)
Fake probability tensor([0.0501, 0.0236], grad_fn=<ViewBackward0>)

Iteration 1

Iteration 2

Iteration 3

Iteration 4

Iteration 5

Iteration 6

Iteration 7

Iteration 8

Iteration 9

Iteration 10

Iteration 11

Iteration 12

Iteration 13

Iteration 14

Iteration 15

Iteration 16

Iteration 17

Iteration 18

Iteration 19

Iteration 20

Iteration 21

Iteration 22

Iteration 23

Iteration 24

Iteration 25

Iteration 26

Iteration 27

Iteration 28

Iteration 29

Iteration 30

Iteration 31

Iteration 32

Iteration 33

Iteration 34

Iteration 35

Iteration 36

Iteration 37

Iteration 38

Iteration 39

Iteration 40

Iteration 41

Iteration 42

Iteration 43

Iteration 44

Iteration 45

Iteration 46

Iteration 47

Iteration 48

I


Iteration 452

Iteration 453

Iteration 454

Iteration 455

Iteration 456

Iteration 457

Iteration 458

Iteration 459

Iteration 460

Iteration 461

Iteration 462

Iteration 463

Iteration 464

Iteration 465

Iteration 466

Iteration 467

Iteration 468

Iteration 469

Iteration 470

Iteration 471

Iteration 472

Iteration 473

Iteration 474

Iteration 475

Iteration 476

Iteration 477

Iteration 478

Iteration 479

Iteration 480

Iteration 481

Iteration 482

Iteration 483

Iteration 484

Iteration 485

Iteration 486

Iteration 487

Iteration 488

Iteration 489

Iteration 490

Iteration 491

Iteration 492

Iteration 493

Iteration 494

Iteration 495

Iteration 496

Iteration 497

Iteration 498

Iteration 499

Iteration 500
Real probability tensor([0.7898, 0.7964], grad_fn=<ViewBackward0>)
Fake probability tensor([0.1193, 0.4308], grad_fn=<ViewBackward0>)

Iteration 501

Iteration 502

Iteration 503

Iteration 504

Iteration 505

Iteration 506

Iteration 507

Iteration 508

Iteration 


Iteration 918

Iteration 919

Iteration 920

Iteration 921

Iteration 922

Iteration 923

Iteration 924

Iteration 925

Iteration 926

Iteration 927

Iteration 928

Iteration 929

Iteration 930

Iteration 931

Iteration 932

Iteration 933

Iteration 934

Iteration 935

Iteration 936

Iteration 937

Iteration 938

Iteration 939

Iteration 940

Iteration 941

Iteration 942

Iteration 943

Iteration 944

Iteration 945

Iteration 946

Iteration 947

Iteration 948

Iteration 949

Iteration 950
Real probability tensor([0.4484, 0.8081], grad_fn=<ViewBackward0>)
Fake probability tensor([0.1800, 0.2098], grad_fn=<ViewBackward0>)

Iteration 951

Iteration 952

Iteration 953

Iteration 954

Iteration 955

Iteration 956

Iteration 957

Iteration 958

Iteration 959

Iteration 960

Iteration 961

Iteration 962

Iteration 963

Iteration 964

Iteration 965

Iteration 966

Iteration 967

Iteration 968

Iteration 969

Iteration 970

Iteration 971

Iteration 972

Iteration 973

Iteration 974

Iteration 

KeyboardInterrupt: 