How to print the output of the wrong prediction of validation dataset? #19

Closed
Williamlizl opened this issue Sep 16, 2021 · 10 comments

Comments

@Williamlizl

No description provided.

@zihangJiang
Owner

You may refer to the code here to compare the output (prediction) and the target (ground truth).

TokenLabeling/validate.py

Lines 238 to 242 in 09bb641

# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
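
To list which samples in a batch were predicted wrongly, rather than only the aggregate accuracy, one option is to compare the arg-max of `output` with `target` directly. The following is a minimal sketch, not code from this repo; the helper name `find_wrong_predictions` and the toy tensors are made up for illustration, and it assumes `output` has shape `(batch, num_classes)` and `target` holds integer class labels.

```python
import torch


def find_wrong_predictions(output, target):
    """Return batch positions, predicted classes and true classes of misclassified samples."""
    pred = output.argmax(dim=1)  # top-1 class per sample
    wrong = torch.nonzero(pred != target, as_tuple=True)[0]  # positions where prediction differs
    return wrong.tolist(), pred[wrong].tolist(), target[wrong].tolist()


# Toy batch: 3 samples, 2 classes (hypothetical values).
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 1, 0])
positions, predicted, expected = find_wrong_predictions(logits, labels)
print(positions, predicted, expected)
```

Inside the validation loop, a call such as `find_wrong_predictions(output.detach(), target)` could sit next to the `accuracy(...)` line; the returned positions are relative to the current batch.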

@Williamlizl
Author

And what if I want to get the directory (file path) of each image together with its prediction?

@zihangJiang
Owner

To get the path of the images, you may refer to

class ImageDatasetWithIndex(ImageDataset):
    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target, index
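
Because this dataset also returns `index`, each prediction can be traced back to its file. The sketch below shows only the bookkeeping step and is an illustration, not part of the repo: `collect_misclassified` and the fake file list are hypothetical, and `filename_fn` stands in for a lookup such as `dataset.parser.filename`, which the snippet above calls as `self.parser.filename(index)`.

```python
def collect_misclassified(indices, preds, targets, filename_fn):
    """Pair every wrongly predicted sample with its file path."""
    records = []
    for idx, pred, tgt in zip(indices, preds, targets):
        if pred != tgt:
            # filename_fn maps a dataset index to a path
            records.append((filename_fn(idx), pred, tgt))
    return records


# Stand-in for the dataset's filename lookup (hypothetical paths).
fake_files = ["val/0/a.jpg", "val/1/b.jpg", "val/1/c.jpg"]
records = collect_misclassified(
    indices=[0, 1, 2],
    preds=[0, 0, 1],
    targets=[0, 1, 1],
    filename_fn=fake_files.__getitem__,
)
for path, pred, tgt in records:
    print(f"{path}: predicted {pred}, expected {tgt}")
```

In a real validation loop the three lists would presumably come from `index.tolist()`, `output.argmax(dim=1).tolist()` and `target.tolist()` for each batch.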

@Williamlizl
Author

Is there no test.py for inference?

@zihangJiang
Owner

You can use this Colab notebook for inference. It uses the VOLO model, but you can switch to LV-ViT by importing the model with from tlt.models import lvvit_s and downloading the pre-trained model here.

@Williamlizl
Author

Williamlizl commented Sep 17, 2021

from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform
model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
model.eval()
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)
RuntimeError                              Traceback (most recent call last)
in
      4 from timm.data import create_transform
      5 model = lvvit_s(img_size=384)
----> 6 load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar')
      7 model.eval()
      8 transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])

~/.local/lib/python3.7/site-packages/tlt/utils/utils.py in load_pretrained_weights(model, checkpoint_path, use_ema, strict, num_classes)
    109 def load_pretrained_weights(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
    110     state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
--> 111     model.load_state_dict(state_dict, strict=strict)
    112
    113

~/.local/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                 self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226

RuntimeError: Error(s) in loading state_dict for LV_ViT:
    Missing key(s) in state_dict: "head.weight", "head.bias", "aux_head.weight", "aux_head.bias".

@zihangJiang
Owner

Please use the latest version of our repo (pip install tlt==0.2.0).
This is a bug in the loading function in tlt/utils.py in our early version, which deleted all classification heads from the checkpoint in order to support transfer learning.
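
To see why dropping the head weights produces the "Missing key(s)" error, here is a small, self-contained sketch of the same failure on a toy module. It does not reproduce the actual tlt loader: `TinyModel` and the key filter are hypothetical stand-ins, and only the behaviour of `torch.nn.Module.load_state_dict` is real.

```python
from torch import nn


class TinyModel(nn.Module):
    """Toy stand-in for a backbone plus classification head (not LV-ViT)."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)


model = TinyModel()
# Simulate a loader that strips the classification head from the checkpoint.
state = {k: v for k, v in model.state_dict().items() if not k.startswith("head.")}

try:
    model.load_state_dict(state, strict=True)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)  # message reports the missing head keys

# With strict=False the load succeeds and the missing keys are only reported.
result = model.load_state_dict(state, strict=False)
print(strict_error)
print(result.missing_keys)
```

Note that `strict=False` avoids the exception but leaves the head at its current (for a fresh model, randomly initialized) weights, so it is not a fix when the head is needed for inference.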

@Williamlizl
Author

from tlt.models import lvvit_s
from PIL import Image
from tlt.utils import load_pretrained_weights
from timm.data import create_transform
model = lvvit_s(img_size=384)
load_pretrained_weights(model=model, checkpoint_path='/home/lbc/GitHub/c/train/LV-ViT/20210912-114053-lvvit_s-384/model_best.pth.tar', strict=False, num_classes=2)
model.eval()
print(model)
transform = create_transform(input_size=384, crop_pct=model.default_cfg['crop_pct'])
image = Image.open('/home/lbc/GitHub/c/train/LV-ViT/validation/1_val/323_l2.jpg')
input_image = transform(image).unsqueeze(0)
If I use model = lvvit_s(img_size=384), it loads the official model, but how do I load my fine-tuned model?

@zihangJiang
Owner

If the number of classes is not 1000, you should also pass num_classes to the model (i.e. model = lvvit_s(img_size=384, num_classes=2))
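
The reason is that the head's shape depends on the number of classes, so a checkpoint fine-tuned for 2 classes only fits a model built with a 2-class head. The sketch below illustrates this with a toy module; `TinyClassifier` is a hypothetical stand-in, not LV-ViT, and the layer sizes are arbitrary.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Toy stand-in whose head size depends on num_classes (not LV-ViT)."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.head = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.head(self.body(x))


# Pretend this state dict came from a model fine-tuned on 2 classes.
checkpoint = TinyClassifier(num_classes=2).state_dict()

# A model built with the default 1000 classes cannot take the 2-class head.
try:
    TinyClassifier().load_state_dict(checkpoint)
    mismatch = False
except RuntimeError:
    mismatch = True

# Building the model with the matching num_classes makes the load succeed.
model = TinyClassifier(num_classes=2)
model.load_state_dict(checkpoint)
model.eval()
with torch.no_grad():
    logits = model(torch.zeros(1, 8))
print(mismatch, tuple(logits.shape))
```

The same idea applies to the real model: construct it with the class count used during fine-tuning before loading the checkpoint.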

@Williamlizl
Author

It does work, thank you
