Skip to content

Commit

Permalink
Simplify and make file extension handling more robust (#3155)
Browse files Browse the repository at this point in the history
  • Loading branch information
hackdna committed Jan 10, 2019
1 parent 936a619 commit 29ffb14
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 49 deletions.
37 changes: 16 additions & 21 deletions refinery/file_store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ def save(self, *args, **kwargs):
if not self.filetype:
# set file type using file extension
try:
extension = self.get_file_extension()
except RuntimeError as exc:
extension = _get_file_extension(self.get_extension())
except FileExtension.DoesNotExist as exc:
logger.warn("Could not assign type to file '%s': %s",
self, exc)
except FileExtension.MultipleObjectsReturned as exc:
logger.critical("Could not assign type to file '%s': %s",
self, exc)
else:
self.filetype = extension.filetype

Expand All @@ -148,24 +151,6 @@ def get_file_size(self):
logger.critical("Error getting size for '%s': %s", self, exc)
return 0

def get_file_extension(self):
    """Look up the FileExtension matching this item's extension

    The string returned by get_extension() is tried first; if no
    FileExtension has that name, the string is reduced with
    _get_extension_from_string() and looked up once more.

    Raises RuntimeError when the second lookup finds nothing or when either
    lookup matches more than one FileExtension.
    """
    candidate = self.get_extension()
    try:
        return FileExtension.objects.get(name=candidate)
    except FileExtension.MultipleObjectsReturned as exc:
        raise RuntimeError(exc)
    except FileExtension.DoesNotExist:
        # no match for the full string: retry below with the reduced form
        pass

    candidate = _get_extension_from_string(candidate)
    try:
        return FileExtension.objects.get(name=candidate)
    except FileExtension.MultipleObjectsReturned as exc:
        raise RuntimeError(exc)
    except FileExtension.DoesNotExist as exc:
        raise RuntimeError(
            "Extension '{}' is not valid: {}".format(candidate, exc)
        )

def get_extension(self):
"""Return extension of datafile name or file name in source"""
if self.datafile.name:
Expand All @@ -174,7 +159,7 @@ def get_extension(self):
return _get_extension_from_string(self.source)

def delete_datafile(self, save_instance=True):
"""Delete datafile on disk and cancel file import"""
"""Delete file from disk or S3 bucket and cancel file import"""
self.terminate_file_import_task()
if self.datafile:
file_name = self.datafile.name
Expand Down Expand Up @@ -293,3 +278,13 @@ def _get_extension_from_string(path):
if len(file_name_parts) > 2: # two or more periods in file name
return '.'.join(file_name_parts[-2:])
return file_name_parts[-1] # one period in file name


def _get_file_extension(extension):
    """Return the FileExtension matching a file name or extension string

    The whole string is looked up first.  On a miss, everything up to and
    including the first period is dropped and the remainder is looked up,
    repeating until a match is found or nothing is left.

    Raises FileExtension.DoesNotExist when the string is exhausted without
    a match; FileExtension.MultipleObjectsReturned is not handled here and
    propagates to the caller.
    """
    remainder = extension
    while True:
        try:
            return FileExtension.objects.get(name=remainder)
        except FileExtension.DoesNotExist:
            if not remainder:
                # nothing left to strip: report the failed lookup
                raise
            # drop the leading period-separated component and try again
            remainder = remainder.partition('.')[2]
72 changes: 44 additions & 28 deletions refinery/file_store/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from override_storage import override_storage

from .models import (FileExtension, FileStoreItem, FileType,
_get_extension_from_string, _map_source,
generate_file_source_translator, get_temp_dir)
_get_extension_from_string, _get_file_extension,
_map_source, generate_file_source_translator,
get_temp_dir)


class FileStoreModuleTest(TestCase):
Expand Down Expand Up @@ -109,27 +110,6 @@ def test_set_remote_file_type_with_multiple_period_file_name(self):
saved_item = FileStoreItem.objects.get(pk=new_item.pk)
self.assertEqual(saved_item.filetype, self.file_type)

def test_get_remote_file_extension(self):
item = FileStoreItem(source=self.url_source)
self.assertEqual(item.get_file_extension(), self.file_extension)

def test_get_remote_file_multi_extension(self):
# TODO: replace with create() when migrations are no longer required
file_type = FileType.objects.get_or_create(name='FASTQ.GZ')[0]
file_extension = FileExtension.objects.get_or_create(
name='fastq.gz', filetype=file_type
)[0]
item = FileStoreItem(source='http://example.org/test.fastq.gz')
self.assertEqual(item.get_file_extension(), file_extension)

def test_get_remote_file_extension_with_multiple_period_file_name(self):
item = FileStoreItem(source='http://example.org/test.name.tdf')
self.assertEqual(item.get_file_extension(), self.file_extension)

def test_get_invalid_remote_file_extension(self):
item = FileStoreItem(source='http://example.org/test.name.invalid')
self.assertRaises(RuntimeError, item.get_file_extension)

def test_file_source_map_translation(self):
with override_settings(
REFINERY_FILE_SOURCE_MAP={
Expand Down Expand Up @@ -191,11 +171,6 @@ def test_set_local_file_type_update(self):
saved_item = FileStoreItem.objects.get(pk=self.item.pk)
self.assertEqual(saved_item.filetype, zip_file_type)

def test_get_local_file_extension(self):
self.item.datafile.save(self.file_name, ContentFile(''))
saved_item = FileStoreItem.objects.get(pk=self.item.pk)
self.assertEqual(saved_item.get_file_extension(), self.file_extension)

def test_delete_local_file_on_instance_delete(self):
self.item.datafile.save(self.file_name, ContentFile(''))
with mock.patch.object(FieldFile, 'path'):
Expand Down Expand Up @@ -307,3 +282,44 @@ def test_no_terminate_on_save_with_no_new_datafile(self):
) as mock_terminate_task:
self.item.save()
mock_terminate_task.assert_not_called()


class FileExtensionTest(TestCase):
    """Checks for _get_file_extension() with single and compound extensions"""

    def setUp(self):
        def register(extension_name, type_name):
            # TODO: replace with create() when migrations are no longer
            # required
            file_type = FileType.objects.get_or_create(name=type_name)[0]
            return FileExtension.objects.get_or_create(
                name=extension_name, filetype=file_type
            )[0]

        self.fastq_extension = register('fastq', 'FASTQ')
        self.gz_extension = register('gz', 'GZ')
        self.fastq_gz_extension = register('fastq.gz', 'FASTQ.GZ')

    def test_get_existing_extension(self):
        # bare extensions and file names with one period resolve alike
        expectations = (
            ('fastq', self.fastq_extension),
            ('random.fastq', self.fastq_extension),
            ('gz', self.gz_extension),
            ('random.gz', self.gz_extension),
        )
        for name, expected in expectations:
            self.assertEqual(_get_file_extension(name), expected)

    def test_get_existing_multi_extension(self):
        # the compound extension is expected rather than plain 'gz'
        for name in ('fastq.gz', 'random.fastq.gz'):
            self.assertEqual(_get_file_extension(name),
                             self.fastq_gz_extension)

    def test_get_blank_extension(self):
        with self.assertRaises(FileExtension.DoesNotExist):
            _get_file_extension('')

    def test_get_non_existing_extension(self):
        with self.assertRaises(FileExtension.DoesNotExist):
            _get_file_extension('invalid')

    def test_get_non_existing_multi_extension(self):
        with self.assertRaises(FileExtension.DoesNotExist):
            _get_file_extension('invalid.extension')

0 comments on commit 29ffb14

Please sign in to comment.