Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dcgan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ After every epoch, models are saved to: `netG_epoch_%d.pth` and `netD_epoch_%d.p
##Downloading the dataset
You can download the LSUN dataset by cloning [this repo](https://github.com/fyu/lsun) and running
```
python donwload.py -c bedroom
python download.py -c bedroom
```

##Usage
Expand Down
24 changes: 18 additions & 6 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda' , action='store_true', help='enables cuda')
parser.add_argument('--ngpu' , type=int, default=1, help='number of GPUs to use')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
print(opt)
Expand All @@ -39,10 +40,14 @@
os.makedirs(opt.outf)
except OSError:
pass
opt.manualSeed = random.randint(1, 10000) # fix seed

if opt.manualSeed is None:
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

Expand Down Expand Up @@ -84,6 +89,7 @@
ndf = int(opt.ndf)
nc = 3


# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
Expand All @@ -93,6 +99,7 @@ def weights_init(m):
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)


class _netG(nn.Module):
def __init__(self, ngpu):
super(_netG, self).__init__()
Expand All @@ -119,18 +126,21 @@ def __init__(self, ngpu):
nn.Tanh()
# state size. (nc) x 64 x 64
)

def forward(self, input):
gpu_ids = None
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
gpu_ids = range(self.ngpu)
return nn.parallel.data_parallel(self.main, input, gpu_ids)


netG = _netG(ngpu)
netG.apply(weights_init)
if opt.netG != '':
netG.load_state_dict(torch.load(opt.netG))
print(netG)


class _netD(nn.Module):
def __init__(self, ngpu):
super(_netD, self).__init__()
Expand All @@ -155,13 +165,15 @@ def __init__(self, ngpu):
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)

def forward(self, input):
gpu_ids = None
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
gpu_ids = range(self.ngpu)
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
return output.view(-1, 1)


netD = _netD(ngpu)
netD.apply(weights_init)
if opt.netD != '':
Expand Down Expand Up @@ -190,8 +202,8 @@ def forward(self, input):
fixed_noise = Variable(fixed_noise)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

for epoch in range(opt.niter):
for i, data in enumerate(dataloader, 0):
Expand Down Expand Up @@ -226,7 +238,7 @@ def forward(self, input):
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.data.fill_(real_label) # fake labels are real for generator cost
label.data.fill_(real_label) # fake labels are real for generator cost
output = netD(fake)
errG = criterion(output, label)
errG.backward()
Expand Down