Skip to content

Commit

Permalink
Import CNN junk predictions and add them to Zooniverse metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
adammcmaster committed Apr 12, 2024
1 parent b758d09 commit c748c5b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 4 deletions.
3 changes: 2 additions & 1 deletion starcatalogue/admin.py
Expand Up @@ -86,12 +86,13 @@ class FoldedLightcurveAdmin(admin.ModelAdmin):
"period_length",
"sigma",
"chi_squared",
"cnn_junk_prediction",
)
search_fields = ("star__superwasp_id",)
fields = (
("star", "zooniversesubject", "created"),
("period_number",),
("period_length", "sigma", "chi_squared"),
("period_length", "sigma", "chi_squared", "cnn_junk_prediction"),
("updated_period_length", "updated_sigma", "updated_chi_squared"),
("image_file", "thumbnail_file", "image_version"),
)
Expand Down
38 changes: 38 additions & 0 deletions starcatalogue/management/commands/create_all_zoo_subjects.py
@@ -0,0 +1,38 @@
from django.core.management.base import BaseCommand, CommandError

import csv

from starcatalogue.models import Star, FoldedLightcurve, ZooniverseSubject


class Command(BaseCommand):
help = "Imports folded lightcurve data (lookup.dat)"

def add_arguments(self, parser):
parser.add_argument("file", nargs=1, type=open)

def handle(self, *args, **options):
r = csv.reader(options["file"][0], delimiter=" ", skipinitialspace=True)
imported_total = 0
for count, row in enumerate(r):
try:
subject_id = int(row[0])
superwasp_id = row[1]
period_number = int(row[3])
except IndexError:
print("Warning: Skipping row {} due to IndexError".format(count))
continue

star = Star.objects.get_or_create(superwasp_id=superwasp_id)
lightcurve = FoldedLightcurve.objects.get_or_create(
star=star,
period_number=period_number,
)

ZooniverseSubject.objects.get_or_create(
zooniverse_id=subject_id, lightcurve=lightcurve
)

imported_total += 1

self.stdout.write("Total imported: {}".format(imported_total))
40 changes: 40 additions & 0 deletions starcatalogue/management/commands/importjunkpredictions.py
@@ -0,0 +1,40 @@
from django.core.management.base import BaseCommand, CommandError

import csv
import pandas

from starcatalogue.models import FoldedLightcurve


class Command(BaseCommand):
help = "Imports CNN junk predictions from a DataFrame into FoldedLightcurves"

def add_arguments(self, parser):
parser.add_argument("file", nargs=1, type=str)

def handle(self, *args, **options):
df = pandas.read_pickle(options["file"][0])
imported_total = 0
lcs = []
total = len(df)
updated = 0
for n, (i, row) in enumerate(df.iterrows(), start=1):
superwasp_id, period_number, _ = i.replace(".gif", "").split("_")
period_number = int(period_number.replace("P", ""))
try:
lcs.append(
FoldedLightcurve.objects.get(
star__superwasp_id=superwasp_id, period_number=period_number
)
)
lcs[-1].cnn_junk_prediction = row["prediction"]
except FoldedLightcurve.DoesNotExist:
pass
if n % 100 == 0 or n == total:
print(f"\r{n} / {total} ({(n / total) * 100:.2f}%)", end="")
if len(lcs) > 0:
updated += FoldedLightcurve.objects.bulk_update(
lcs, ["cnn_junk_prediction"]
)
lcs = []
print(f"\nUpdated {updated}")
@@ -0,0 +1,18 @@
# Generated by Django 4.2.11 on 2024-04-12 10:06

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('starcatalogue', '0044_alter_dataexport_doi'),
]

operations = [
migrations.AddField(
model_name='foldedlightcurve',
name='cnn_junk_prediction',
field=models.FloatField(null=True),
),
]
5 changes: 4 additions & 1 deletion starcatalogue/models.py
Expand Up @@ -329,6 +329,8 @@ class FoldedLightcurve(models.Model, ImageGenerator):
updated_sigma = models.FloatField(null=True)
updated_chi_squared = models.FloatField(null=True)

cnn_junk_prediction = models.FloatField(null=True)

image_file = models.ImageField(null=True, upload_to=lightcurve_upload_to)
thumbnail_file = models.ImageField(null=True, upload_to=lightcurve_upload_to)
images_celery_task_id = models.UUIDField(null=True)
Expand Down Expand Up @@ -398,7 +400,7 @@ def timeseries(self):


class ZooniverseSubject(models.Model):
CURRENT_METADATA_VERSION = 1.0
CURRENT_METADATA_VERSION = 2.0

zooniverse_id = models.IntegerField(unique=True)
lightcurve = models.OneToOneField(to=FoldedLightcurve, on_delete=models.CASCADE)
Expand Down Expand Up @@ -433,6 +435,7 @@ def subject_metadata(self):
"!Simbad": f"http://simbad.u-strasbg.fr/simbad/sim-coo?Coord={ra}+{dec}&Radius=2&Radius.unit=arcmin&submit=submit+query",
"!ASAS-SN Photometry": f"https://asas-sn.osu.edu/photometry?ra={ra}&dec={dec}&radius=2",
"!VeSPA": f"https://{settings.ALLOWED_HOSTS[0]}{self.lightcurve.get_period_url()}",
"Junk Prediction": self.lightcurve.cnn_junk_prediction,
}

@property
Expand Down
4 changes: 4 additions & 0 deletions starcatalogue/tasks.py
Expand Up @@ -206,6 +206,10 @@ def generate_star_images(star_id):
@shared_task
def save_zooniverse_metadata(vespa_subject_id):
vespa_subject = ZooniverseSubject.objects.get(id=vespa_subject_id)

if vespa_subject.lightcurve.cnn_junk_prediction is None:
return

zoo_subject = Subject.find(vespa_subject.zooniverse_id)

zoo_subject.metadata = vespa_subject.subject_metadata
Expand Down
8 changes: 6 additions & 2 deletions vespa/celery.py
Expand Up @@ -125,8 +125,12 @@ def set_locations():
def set_zooniverse_metadata():
from starcatalogue.models import ZooniverseSubject

for subject in ZooniverseSubject.objects.filter(
for subject in ZooniverseSubject.objects.exclude(
lightcurve__cnn_junk_prediction=None
).filter(
Q(metadata_version=None)
| Q(metadata_version__lt=ZooniverseSubject.CURRENT_METADATA_VERSION)
)[:1000]:
)[
:10000
]:
subject.save_metadata()

0 comments on commit c748c5b

Please sign in to comment.