Skip to content

Commit

Permalink
Merge pull request #135 from scipion-em/IH_DeepHEMNMA
Browse files Browse the repository at this point in the history
IH Fix inference parameters
  • Loading branch information
ilyes-hm committed Nov 14, 2022
2 parents 660ccc7 + bb00290 commit 326d7e7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 15 deletions.
31 changes: 23 additions & 8 deletions continuousflex/protocols/utilities/deep_hemnma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def norm(imgs_path, output_path, FLAG, mode, batch_size):
random_seed = 42
validation_split = .2
shuffle_dataset = True

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor((1-validation_split) * dataset_size))
Expand All @@ -25,8 +26,8 @@ def norm(imgs_path, output_path, FLAG, mode, batch_size):

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
print('the train set size is: {} images'.format(len(train_sampler)))
print('the validation set size is: {} images'.format(len(valid_sampler)))
#print('the train set size is: {} images'.format(len(train_sampler)))
#print('the validation set size is: {} images'.format(len(valid_sampler)))
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
sum_, squared_sum_, num_batches = 0, 0, 0
Expand Down Expand Up @@ -60,8 +61,21 @@ def train(imgs_path, output_path, epochs=400, batch_size=2, lr=1e-4, flag=0, dev
DEVICE = 'cpu'
mean, std = norm(imgs_path, output_path, FLAG, mode, batch_size)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean), (std))])
dataset = cryodata(imgs_path, output_path, flag=FLAG, mode = mode, transform=transform)
dataset_size = len(dataset)
transform1 = transforms.Compose([transforms.ToTensor(),
transforms.RandomRotation((-45, 45)),
transforms.Normalize((mean), (std))])
dataset1 = cryodata(imgs_path, output_path, flag=FLAG, mode= mode,
transform=transform)
dataset2 = cryodata(imgs_path, output_path, flag=FLAG, mode= mode,
transform=transform1)
transform2 = transforms.Compose([transforms.ToTensor(),
transforms.RandomRotation((-90, 90)),
transforms.Normalize((mean), (std))])
dataset3 = cryodata(imgs_path, output_path, flag=FLAG, mode= mode,
transform=transform2)
increased_dataset = torch.utils.data.ConcatDataset([dataset1, dataset2, dataset3])
#dataset = cryodata(imgs_path, output_path, flag=FLAG, mode = mode, transform=transform)
dataset_size = len(increased_dataset)
indices = list(range(dataset_size))
split = int(np.floor((1-validation_split) * dataset_size))

Expand All @@ -72,10 +86,10 @@ def train(imgs_path, output_path, epochs=400, batch_size=2, lr=1e-4, flag=0, dev

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
print('the train set size is: {} images'.format(len(train_sampler)))
print('the validation set size is: {} images'.format(len(valid_sampler)))
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
print('the train set size is: {} images'.format(len(train_sampler)//3))
print('the validation set size is: {} images'.format(len(valid_sampler)//3))
train_loader = DataLoader(increased_dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(increased_dataset, batch_size=batch_size, sampler=valid_sampler)

im, p = next(iter(train_loader))
if FLAG=='nma':
Expand All @@ -97,6 +111,7 @@ def train(imgs_path, output_path, epochs=400, batch_size=2, lr=1e-4, flag=0, dev
running_loss = 0.0

for img, params in train_loader:
img = img/255.
optimizer.zero_grad()
pred_params = model(img.to(DEVICE), 'train')
l = criterion(params.to(DEVICE), pred_params)
Expand Down
24 changes: 17 additions & 7 deletions continuousflex/protocols/utilities/deep_hemnma_infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import torch.nn as nn
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from continuousflex.protocols.utilities.processing_dh.data import cryodata
from continuousflex.protocols.utilities.processing_dh.utils import quater2euler, reverse_min_max
Expand All @@ -10,7 +8,17 @@
from pathlib import Path
import sys
import pwem.emlib.metadata as md

def norm(imgs_path, weights_path, flag, mode, batch_size):
    """Estimate the per-channel mean and std of the image dataset.

    Streams the images batch-by-batch through a DataLoader, averages the
    per-batch channel means and means-of-squares, and derives the std as
    sqrt(E[x^2] - E[x]^2).

    NOTE(review): averaging per-batch statistics gives a truncated final
    batch the same weight as a full one — slight bias unless batch_size
    divides the dataset size; confirm this is acceptable upstream.
    """
    loader = DataLoader(
        cryodata(imgs_path, weights_path, flag=flag, mode=mode,
                 transform=transforms.ToTensor()),
        batch_size=batch_size)
    channel_sum = 0
    channel_sq_sum = 0
    batch_count = 0
    # Accumulate E[x] and E[x^2] per channel (reduce over batch, H, W).
    for batch, _ in loader:
        channel_sum = channel_sum + torch.mean(batch, dim=[0, 2, 3])
        channel_sq_sum = channel_sq_sum + torch.mean(batch ** 2, dim=[0, 2, 3])
        batch_count += 1
    mean = channel_sum / batch_count
    std = (channel_sq_sum / batch_count - mean ** 2) ** 0.5
    return mean, std
def infer(imgs_path, weights_path, output_path, num_modes, batch_size=2, flag=0, device=0, mode='inference'):
FLAG = ''
if flag==0:
Expand All @@ -27,8 +35,9 @@ def infer(imgs_path, weights_path, output_path, num_modes, batch_size=2, flag=0,
else:
DEVICE = 'cpu'


dataset = cryodata(imgs_path, weights_path, flag=FLAG, mode = mode, transform=transforms.ToTensor())
mean, std = norm(imgs_path, weights_path, flag=FLAG, mode=mode, batch_size=batch_size)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean), (std))])
dataset = cryodata(imgs_path, weights_path, flag=FLAG, mode=mode, transform=transform)

dataset_size = len(dataset)
print('the train set size is: {} images'.format(dataset_size))
Expand All @@ -45,13 +54,14 @@ def infer(imgs_path, weights_path, output_path, num_modes, batch_size=2, flag=0,
model = deephemnma(2).to(DEVICE)
predictions = np.zeros((dataset_size, 2), dtype='float32')
elif FLAG=='all':
model = deephemnma(9).to(DEVICE)
model = deephemnma(6+num_modes).to(DEVICE)
predictions = np.zeros((dataset_size, 6+num_modes), dtype='float32')

model.load_state_dict(torch.load(weights_path))
model.eval()
with torch.no_grad():
i = 0
for img, params in data_loader:
for img, image_name in data_loader:
pred_params = model(img.to(DEVICE), mode)
predictions[i * batch_size:(i + 1) * batch_size, :] = pred_params.cpu()
i+=1
Expand Down

0 comments on commit 326d7e7

Please sign in to comment.