In [2]:
from generator import Generator
from critic import Critic
from dog_dataset import DogData

In [3]:
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torch import optim

In [8]:
batchSize = 10

In [9]:
transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) ] )
dogdata = DogData(transform)
dogloader = DataLoader(dogdata, shuffle=True, batch_size=batchSize, num_workers=3)

In [10]:
print(dogdata[0][0])

tensor([[[0.0157, 0.0039, 0.0039,  ..., 0.0000, 0.0000, 0.0000],
         [0.0157, 0.0157, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
         [0.0157, 0.0157, 0.0118,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0235, 0.0157, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
         [0.0157, 0.0078, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
         [0.0078, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.

In [5]:
def get_critic_scores(real_batch, batchSize, critic, generator):
    generator.batchSize = batchSize
    critic.batchSize = batchSize
    #generate fake images
    randoBatch = torch.randn([batchSize, 100, 1, 1])
    generated = generator(randoBatch)
    #compute scores of both real and fake using the critic
    real_scores = critic(real_batch[0])
    gen_scores = critic(generated)
    return ( torch.sum(real_scores) , torch.sum(gen_scores) )

In [6]:
def get_generator_score(batchSize, critic, generator):
    generator.batchSize = batchSize
    #generate fake images
    randoBatch = torch.randn([batchSize, 100, 1, 1])
    generated = generator(randoBatch)
    #compute scores of generated using critic
    gen_scores = critic(generated)
    return torch.sum(gen_scores)

In [40]:
class WeightClipper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            w = w.clamp(-0.01,0.01)
            module.weight.data = w

In [54]:
critic = Critic()
generator = Generator()
clipper = WeightClipper()
critic.apply(clipper)

Critic(
  (convBlocks): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
  )
  (fcBlocks): Sequential(
    (0): Sequential(
      (0): Linear(in_features=67712, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Se

In [55]:
num_epochs = 300
critic_epochs = 5
gen_epochs = 1

In [56]:
criticOptim = optim.RMSprop(critic.parameters(), lr=0.0005, momentum=0)
genOptim = optim.RMSprop(generator.parameters(), lr=0.05, momentum=0)

In [None]:
for epoch in range(num_epochs):
    print('Entering epoch yay: ' + str(epoch))
    realScore = 0.0 
    fakeScore = 0.0
    mainLoss = 0.0
    #critic training loop
    for batch_no, batch in enumerate(dogloader):
        print('Critic batch: ' + str(batch_no))
        scores = get_critic_scores(batch, batchSize, critic, generator)
        real_score = scores[0]
        gen_score = scores[1]
        #loss function
        loss = -( real_score - gen_score ) #+ 0.01 * ( torch.pow(real_score,2) + torch.pow(gen_score,2) ) #penalize for distance away from 0
        #some tracking of numbers here
        realScore += real_score.item()
        fakeScore += gen_score.item()
        mainLoss += loss.item()
        #optimize
        criticOptim.zero_grad()
        loss.backward()
        criticOptim.step()
        
        if (batch_no % critic_epochs == critic_epochs-1):
            break
    print(realScore)
    print(fakeScore)
    print(mainLoss)
    
    mainLoss = 0.0
    #generator training loop
    for gepoch in range(gen_epochs):
        print('Generator batch: ' + str(gepoch))
        gen_score = get_generator_score(batchSize, critic, generator)
        #loss function
        loss = -gen_score
        #tracking of numbers
        mainLoss += loss.item()
        #optimize
        genOptim.zero_grad()
        loss.backward()
        genOptim.step()
    print(mainLoss)

Entering epoch yay: 0
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.3182501196861267
-3.468188524246216
-0.1499384045600891
Generator batch: 0
0.763968288898468
Entering epoch yay: 1
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2965505719184875
-3.9409899711608887
-0.6444393992424011
Generator batch: 0
0.8688773512840271
Entering epoch yay: 2
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.30239200592041
-4.495228171348572
-1.1928361654281616
Generator batch: 0
1.0026843547821045
Entering epoch yay: 3
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.304612934589386
-5.361829161643982
-2.057216227054596
Generator batch: 0
1.172453761100769
Entering epoch yay: 4
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.3027010560035706
-5.845604419708252
-2.5429033637046814
Generator batch: 0
1.3068358898162842
Entering e

0.6588812470436096
Entering epoch yay: 41
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2944058179855347
-3.2944071888923645
-1.3709068298339844e-06
Generator batch: 0
0.6588811874389648
Entering epoch yay: 42
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.294405996799469
-3.2944063544273376
-3.5762786865234375e-07
Generator batch: 0
0.6588810682296753
Entering epoch yay: 43
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.294405937194824
-3.2944077849388123
-1.8477439880371094e-06
Generator batch: 0
0.6588815450668335
Entering epoch yay: 44
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2944052815437317
-3.2944076657295227
-2.384185791015625e-06
Generator batch: 0
0.6588815450668335
Entering epoch yay: 45
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2944050431251526
-3.294407367706299
-2.3245811462402344e-0

Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943994402885437
-3.29440039396286
-9.5367431640625e-07
Generator batch: 0
0.6588801741600037
Entering epoch yay: 82
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.294399380683899
-3.2944003343582153
-9.5367431640625e-07
Generator batch: 0
0.6588799953460693
Entering epoch yay: 83
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943989038467407
-3.2943997383117676
-8.344650268554688e-07
Generator batch: 0
0.6588802337646484
Entering epoch yay: 84
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.294398546218872
-3.29440039396286
-1.8477439880371094e-06
Generator batch: 0
0.6588801145553589
Entering epoch yay: 85
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943981289863586
-3.2944007515907288
-2.6226043701171875e-06
Generator batch: 0
0.658880352973938
Entering epoch yay: 86
Critic batch: 0
Criti

-3.294382631778717
-3.294383943080902
-1.3113021850585938e-06
Generator batch: 0
0.6588767170906067
Entering epoch yay: 122
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943820357322693
-3.2943829894065857
-9.5367431640625e-07
Generator batch: 0
0.6588760614395142
Entering epoch yay: 123
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943809628486633
-3.2943820357322693
-1.0728836059570312e-06
Generator batch: 0
0.6588761806488037
Entering epoch yay: 124
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.294380307197571
-3.294380843639374
-5.364418029785156e-07
Generator batch: 0
0.6588761210441589
Entering epoch yay: 125
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batch: 3
Critic batch: 4
-3.2943795919418335
-3.294380843639374
-1.2516975402832031e-06
Generator batch: 0
0.6588760614395142
Entering epoch yay: 126
Critic batch: 0
Critic batch: 1
Critic batch: 2
Critic batc

In [None]:
import matplotlib.pyplot as plt
plt.imshow(dogdata[0][0].permute(1,2,0))
plt.show()

critic.batchSize = 1
critic.eval()
print(critic(dogdata[800][0].unsqueeze(0)))
rando = torch.randn([1, 100, 1, 1])
print(critic(generator(rando)))

gen = generator(rando)
plt.imshow(gen[0].permute(1,2,0).detach().numpy())
plt.show()